
Source code for torchtext.datasets.nli

from .. import data


class ShiftReduceField(data.Field):

    def __init__(self):

        super(ShiftReduceField, self).__init__(preprocessing=lambda parse: [
            'reduce' if t == ')' else 'shift' for t in parse if t != '('])

        self.build_vocab([['reduce'], ['shift']])

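# Illustrative sketch (not part of the original module): assuming the default
# whitespace tokenizer, Field.preprocess runs the preprocessing pipeline above,
# turning a binary-parse string into shift/reduce transitions:
#
#   >>> field = ShiftReduceField()
#   >>> field.preprocess('( ( the cat ) sat )')
#   ['shift', 'shift', 'reduce', 'shift', 'reduce']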

class ParsedTextField(data.Field):
    """
        Field for parsed sentences data in NLI datasets.
        Expensive tokenization could be omitted from the pipeline as
        the parse tree annotations are already in tokenized form.
    """

    def __init__(self, eos_token='<pad>', lower=False, reverse=False):
        if reverse:
            super(ParsedTextField, self).__init__(
                eos_token=eos_token, lower=lower,
                preprocessing=lambda parse: [t for t in parse if t not in ('(', ')')],
                postprocessing=lambda parse, _: [list(reversed(p)) for p in parse],
                include_lengths=True)
        else:
            super(ParsedTextField, self).__init__(
                eos_token=eos_token, lower=lower,
                preprocessing=lambda parse: [t for t in parse if t not in ('(', ')')],
                include_lengths=True)

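# Illustrative sketch (not part of the original module): the preprocessing
# pipeline above simply drops the parse parentheses, leaving the plain token
# sequence; with reverse=True, the postprocessing step additionally reverses
# each example when the batch is numericalized:
#
#   >>> field = ParsedTextField()
#   >>> field.preprocess('( ( the cat ) sat )')
#   ['the', 'cat', 'sat']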

class NLIDataset(data.TabularDataset):

    urls = []
    dirname = ''
    name = 'nli'

    @staticmethod
    def sort_key(ex):
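        # interleave_keys bit-interleaves the two lengths, so a BucketIterator
        # sorts on premise and hypothesis length simultaneously and batches
        # examples that need little padding on either side.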
        return data.interleave_keys(
            len(ex.premise), len(ex.hypothesis))

    @classmethod
    def splits(cls, text_field, label_field, parse_field=None,
               extra_fields={}, root='.data', train='train.jsonl',
               validation='val.jsonl', test='test.jsonl'):
        """Create dataset objects for splits of the SNLI dataset.

        This is the most flexible way to use the dataset.

        Arguments:
            text_field: The field that will be used for premise and hypothesis
                data.
            label_field: The field that will be used for label data.
            parse_field: The field that will be used for shift-reduce parser
                transitions, or None to not include them.
            extra_fields: A dict mapping a JSON key to a (field_name, Field)
                tuple.
            root: The root directory that the dataset's zip archive will be
                expanded into.
            train: The filename of the train data. Default: 'train.jsonl'.
            validation: The filename of the validation data, or None to not
                load the validation set. Default: 'val.jsonl'.
            test: The filename of the test data, or None to not load the test
                set. Default: 'test.jsonl'.
        """
        path = cls.download(root)

        if parse_field is None:
            fields = {'sentence1': ('premise', text_field),
                      'sentence2': ('hypothesis', text_field),
                      'gold_label': ('label', label_field)}
        else:
            fields = {'sentence1_binary_parse': [('premise', text_field),
                                                 ('premise_transitions', parse_field)],
                      'sentence2_binary_parse': [('hypothesis', text_field),
                                                 ('hypothesis_transitions', parse_field)],
                      'gold_label': ('label', label_field)}

        for key in extra_fields:
            if key not in fields.keys():
                fields[key] = extra_fields[key]

        return super(NLIDataset, cls).splits(
            path, root, train, validation, test,
            format='json', fields=fields,
            filter_pred=lambda ex: ex.label != '-')

    @classmethod
    def iters(cls, batch_size=32, device=0, root='.data',
              vectors=None, trees=False, **kwargs):
        """Create iterator objects for splits of the SNLI dataset.

        This is the simplest way to use the dataset, and assumes common
        defaults for field, vocabulary, and iterator parameters.

        Arguments:
            batch_size: Batch size.
            device: Device to create batches on. Use -1 for CPU and None for
                the currently active GPU device.
            root: The root directory that the dataset's zip archive will be
                expanded into; the data files are stored in the dataset's own
                subdirectory (e.g. snli_1.0) beneath it.
            vectors: One of the available pretrained vectors, or a list of
                pretrained vectors (see Vocab.load_vectors).
            trees: Whether to include shift-reduce parser transitions.
                Default: False.
            Remaining keyword arguments: Passed to the splits method.
        """
        if trees:
            TEXT = ParsedTextField()
            TRANSITIONS = ShiftReduceField()
        else:
            TEXT = data.Field(tokenize='spacy')
            TRANSITIONS = None
        LABEL = data.Field(sequential=False)

        train, val, test = cls.splits(
            TEXT, LABEL, TRANSITIONS, root=root, **kwargs)

        TEXT.build_vocab(train, vectors=vectors)
        LABEL.build_vocab(train)

        return data.BucketIterator.splits(
            (train, val, test), batch_size=batch_size, device=device)


class SNLI(NLIDataset):
    urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip']
    dirname = 'snli_1.0'
    name = 'snli'

    @classmethod
    def splits(cls, text_field, label_field, parse_field=None, root='.data',
               train='snli_1.0_train.jsonl',
               validation='snli_1.0_dev.jsonl',
               test='snli_1.0_test.jsonl'):
        return super(SNLI, cls).splits(text_field, label_field,
                                       parse_field=parse_field, root=root,
                                       train=train, validation=validation,
                                       test=test)


class MultiNLI(NLIDataset):
    urls = ['http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip']
    dirname = 'multinli_1.0'
    name = 'multinli'

    @classmethod
    def splits(cls, text_field, label_field, parse_field=None,
               genre_field=None, root='.data',
               train='multinli_1.0_train.jsonl',
               validation='multinli_1.0_dev_matched.jsonl',
               test='multinli_1.0_dev_mismatched.jsonl'):
        extra_fields = {}
        if genre_field is not None:
            extra_fields["genre"] = ("genre", genre_field)

        return super(MultiNLI, cls).splits(text_field, label_field,
                                           parse_field=parse_field,
                                           extra_fields=extra_fields,
                                           root=root, train=train,
                                           validation=validation, test=test)


class XNLI(NLIDataset):
    urls = ['http://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip']
    dirname = 'XNLI-1.0'
    name = 'xnli'

    @classmethod
    def splits(cls, text_field, label_field, genre_field=None,
               language_field=None, root='.data',
               validation='xnli.dev.jsonl',
               test='xnli.test.jsonl'):
        extra_fields = {}
        if genre_field is not None:
            extra_fields["genre"] = ("genre", genre_field)
        if language_field is not None:
            extra_fields["language"] = ("language", language_field)

        return super(XNLI, cls).splits(text_field, label_field,
                                       extra_fields=extra_fields,
                                       root=root, train=None,
                                       validation=validation, test=test)

    @classmethod
    def iters(cls, *args, **kwargs):
        raise NotImplementedError('XNLI dataset does not support iters')
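
# Illustrative usage of the dataset-specific splits (not part of the original
# module); the extra genre/language columns are attached only when the
# corresponding fields are supplied:
#
#   TEXT = data.Field(tokenize='spacy')
#   LABEL = data.Field(sequential=False)
#   GENRE = data.Field(sequential=False)
#   LANG = data.Field(sequential=False)
#
#   mnli_train, mnli_dev_matched, mnli_dev_mismatched = MultiNLI.splits(
#       TEXT, LABEL, genre_field=GENRE)
#
#   # XNLI ships no training split, so only validation and test are returned.
#   xnli_dev, xnli_test = XNLI.splits(
#       TEXT, LABEL, genre_field=GENRE, language_field=LANG)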