
# Source code for torchtext.datasets.nli

from .. import data

class ShiftReduceField(data.Field):
    """Field that maps a bracketed binary parse to shift/reduce transitions.

    Opening brackets ``'('`` are discarded, every closing bracket ``')'``
    becomes ``'reduce'``, and every remaining token becomes ``'shift'``.
    The vocabulary is built at construction time from the two transition
    labels.
    """

    def __init__(self):
        def to_transitions(parse):
            # '(' produces no transition at all, so filter it out first.
            kept = (token for token in parse if token != '(')
            return ['reduce' if token == ')' else 'shift' for token in kept]

        super(ShiftReduceField, self).__init__(preprocessing=to_transitions)

        self.build_vocab([[label] for label in ('reduce', 'shift')])

class ParsedTextField(data.Field):
    """Field for parsed sentences data in NLI datasets.

    Expensive tokenization could be omitted from the pipeline as
    the parse tree annotations are already in tokenized form.

    Arguments:
        eos_token: End-of-sentence token, passed through to ``data.Field``.
            Default: '<pad>'.
        lower: Passed through to ``data.Field``. Default: False.
        reverse: If True, a postprocessing step reverses every token
            sequence in the batch. Default: False.
    """

    def __init__(self, eos_token='<pad>', lower=False, reverse=False):
        field_kwargs = dict(
            eos_token=eos_token,
            lower=lower,
            # Parse brackets encode tree structure, not words: drop them so
            # only the already-tokenized words remain.
            preprocessing=lambda parse: [t for t in parse if t not in ('(', ')')],
            include_lengths=True)
        if reverse:
            # Reverse each sequence token-by-token after numericalization
            # preparation; the second (vocab) argument is unused.
            field_kwargs['postprocessing'] = (
                lambda parse, _: [list(reversed(p)) for p in parse])
        super(ParsedTextField, self).__init__(**field_kwargs)

class NLIDataset(data.TabularDataset):
    """Base class for natural language inference (NLI) datasets.

    Subclasses set ``urls``, ``dirname`` and ``name`` and typically override
    ``splits`` with dataset-specific default file names. Data files are read
    as JSON lines, and examples whose label is ``'-'`` are filtered out.
    """

    urls = []
    dirname = ''
    name = 'nli'

    @staticmethod
    def sort_key(ex):
        """Return ``data.interleave_keys`` of the premise/hypothesis lengths."""
        return data.interleave_keys(
            len(ex.premise), len(ex.hypothesis))

    @classmethod
    def splits(cls, text_field, label_field, parse_field=None,
               extra_fields=None, root='.data', train='train.jsonl',
               validation='val.jsonl', test='test.jsonl'):
        """Create dataset objects for splits of an NLI dataset.

        This is the most flexible way to use the dataset.

        Arguments:
            text_field: The field that will be used for premise and hypothesis
                data.
            label_field: The field that will be used for label data.
            parse_field: The field that will be used for shift-reduce parser
                transitions, or None to not include them.
            extra_fields: A dict[json_key: Tuple(field_name, Field)] of extra
                JSON keys to load. Keys already used by the built-in mapping
                are ignored. Default: None (no extra fields).
            root: The root directory that the dataset's zip archive will be
                expanded into.
            train: The filename of the train data. Default: 'train.jsonl'.
            validation: The filename of the validation data, or None to not
                load the validation set. Default: 'val.jsonl'.
            test: The filename of the test data, or None to not load the test
                set. Default: 'test.jsonl'.

        Returns:
            The result of ``data.TabularDataset.splits`` for the requested
            splits.
        """
        # NOTE(review): presumably fetches/extracts ``cls.urls`` under
        # ``root`` and returns the data directory -- confirm against
        # ``data.Dataset.download``.
        path = cls.download(root)

        if parse_field is None:
            fields = {'sentence1': ('premise', text_field),
                      'sentence2': ('hypothesis', text_field),
                      'gold_label': ('label', label_field)}
        else:
            # Each binary-parse column feeds two fields: the text itself and
            # its transition sequence.
            fields = {'sentence1_binary_parse': [('premise', text_field),
                                                 ('premise_transitions', parse_field)],
                      'sentence2_binary_parse': [('hypothesis', text_field),
                                                 ('hypothesis_transitions', parse_field)],
                      'gold_label': ('label', label_field)}

        # Extra fields never override the built-in mapping above.
        for key, value in (extra_fields or {}).items():
            fields.setdefault(key, value)

        return super(NLIDataset, cls).splits(
            path, root, train, validation, test,
            format='json', fields=fields,
            # Examples labelled '-' are excluded from every split.
            filter_pred=lambda ex: ex.label != '-')

    @classmethod
    def iters(cls, batch_size=32, device=0, root='.data',
              vectors=None, trees=False, **kwargs):
        """Create iterator objects for splits of an NLI dataset.

        This is the simplest way to use the dataset, and assumes common
        defaults for field, vocabulary, and iterator parameters.

        Arguments:
            batch_size: Batch size.
            device: Device to create batches on. Use -1 for CPU and None for
                the currently active GPU device.
            root: The root directory that the dataset's zip archive will be
                expanded into (the data files presumably end up under its
                ``cls.dirname`` subdirectory).
            vectors: one of the available pretrained vectors or a list with each
                element one of the available pretrained vectors (see Vocab.load_vectors)
            trees: Whether to include shift-reduce parser transitions.
                Default: False.
            Remaining keyword arguments: Passed to the splits method.

        Returns:
            The result of ``data.BucketIterator.splits`` over the train,
            validation and test datasets.
        """
        if trees:
            TEXT = ParsedTextField()
            TRANSITIONS = ShiftReduceField()
        else:
            TEXT = data.Field(tokenize='spacy')
            TRANSITIONS = None
        LABEL = data.Field(sequential=False)

        train, val, test = cls.splits(
            TEXT, LABEL, TRANSITIONS, root=root, **kwargs)

        TEXT.build_vocab(train, vectors=vectors)
        # The label field also needs a vocabulary; without it the labels
        # cannot be numericalized when batches are built.
        LABEL.build_vocab(train)

        return data.BucketIterator.splits(
            (train, val, test), batch_size=batch_size, device=device)

class SNLI(NLIDataset):
    """The SNLI dataset (version 1.0 archive layout)."""

    urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip']
    dirname = 'snli_1.0'
    name = 'snli'

    @classmethod
    def splits(cls, text_field, label_field, parse_field=None, root='.data',
               train='snli_1.0_train.jsonl', validation='snli_1.0_dev.jsonl',
               test='snli_1.0_test.jsonl'):
        """Create SNLI dataset splits.

        Arguments have the same meaning as in ``NLIDataset.splits``; only the
        default file names are SNLI-specific.
        """
        return super(SNLI, cls).splits(
            text_field, label_field, parse_field=parse_field, root=root,
            train=train, validation=validation, test=test)
class MultiNLI(NLIDataset):
    """The MultiNLI dataset (version 1.0 archive layout).

    By default the matched dev file is used as the validation split and the
    mismatched dev file as the test split.
    """

    urls = ['http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip']
    dirname = 'multinli_1.0'
    name = 'multinli'

    @classmethod
    def splits(cls, text_field, label_field, parse_field=None,
               genre_field=None, root='.data',
               train='multinli_1.0_train.jsonl',
               validation='multinli_1.0_dev_matched.jsonl',
               test='multinli_1.0_dev_mismatched.jsonl'):
        """Create MultiNLI dataset splits.

        Arguments have the same meaning as in ``NLIDataset.splits``, plus:
            genre_field: Optional field for the ``"genre"`` JSON key, stored
                on each example as ``genre``. Default: None (not loaded).
        """
        extra_fields = {}
        if genre_field is not None:
            extra_fields["genre"] = ("genre", genre_field)
        return super(MultiNLI, cls).splits(
            text_field, label_field, parse_field=parse_field,
            extra_fields=extra_fields, root=root,
            train=train, validation=validation, test=test)
class XNLI(NLIDataset):
    """The XNLI dataset (version 1.0 archive layout).

    Only validation and test splits are loaded; there is no training split,
    and ``iters`` is not supported.
    """

    urls = ['http://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip']
    dirname = 'XNLI-1.0'
    name = 'xnli'

    @classmethod
    def splits(cls, text_field, label_field, genre_field=None,
               language_field=None, root='.data',
               validation='xnli.dev.jsonl', test='xnli.test.jsonl'):
        """Create XNLI dataset splits (validation and test only).

        Arguments have the same meaning as in ``NLIDataset.splits``, plus:
            genre_field: Optional field for the ``"genre"`` JSON key.
                Default: None (not loaded).
            language_field: Optional field for the ``"language"`` JSON key.
                Default: None (not loaded).
        """
        extra_fields = {}
        if genre_field is not None:
            extra_fields["genre"] = ("genre", genre_field)
        if language_field is not None:
            extra_fields["language"] = ("language", language_field)
        return super(XNLI, cls).splits(
            text_field, label_field, extra_fields=extra_fields, root=root,
            train=None, validation=validation, test=test)

    @classmethod
    def iters(cls, *args, **kwargs):
        raise NotImplementedError('XNLI dataset does not support iters')
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources