
Source code for torchtext.experimental.datasets.text_classification

import logging
import torch
import io
from torchtext.utils import download_from_url, extract_archive
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.vocab import Vocab
from torchtext.datasets import TextClassificationDataset

URLS = {
    'IMDB':
        'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
}


def _create_data_from_iterator(vocab, iterator, removed_tokens):
    # Convert each (label, tokens) pair into (label, iterator of vocab ids),
    # dropping any tokens listed in removed_tokens.
    for cls, tokens in iterator:
        yield cls, iter(map(lambda x: vocab[x],
                        filter(lambda x: x not in removed_tokens, tokens)))

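A minimal sketch of what this helper yields, assuming a small hand-built Vocab; the toy counter, the sample pair and the '<pad>' removal below are illustrative only, not part of the module:

from collections import Counter
from torchtext.vocab import Vocab

toy_vocab = Vocab(Counter({'good': 3, 'movie': 2, 'bad': 1}))
pairs = iter([(1, ['good', 'movie', '<pad>'])])
label, ids = next(_create_data_from_iterator(toy_vocab, pairs, removed_tokens=['<pad>']))
print(label, list(ids))  # 1, followed by the vocab indices of 'good' and 'movie'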

def _imdb_iterator(key, extracted_files, tokenizer, ngrams, yield_cls=False):
    # Walk the extracted archive and yield one tokenized review per file
    # belonging to the requested split ('train' or 'test'). Files whose path
    # contains 'urls' are index files, not reviews, and are skipped.
    for fname in extracted_files:
        if 'urls' in fname:
            continue
        elif key in fname and ('pos' in fname or 'neg' in fname):
            with io.open(fname, encoding="utf8") as f:
                # Positive reviews are labeled 1, negative reviews 0.
                label = 1 if 'pos' in fname else 0
                if yield_cls:
                    yield label, ngrams_iterator(tokenizer(f.read()), ngrams)
                else:
                    yield ngrams_iterator(tokenizer(f.read()), ngrams)


def _generate_data_iterators(dataset_name, root, ngrams, tokenizer, data_select):
    if not tokenizer:
        tokenizer = get_tokenizer("basic_english")

    if not set(data_select).issubset(set(('train', 'test'))):
        raise TypeError('Given data selection {} is not supported!'.format(data_select))

    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    iters_group = {}
    if 'train' in data_select:
        # A separate, label-free pass over the training files is kept under
        # the 'vocab' key; it is consumed later to build the vocabulary.
        iters_group['vocab'] = _imdb_iterator('train', extracted_files,
                                              tokenizer, ngrams)
    for item in data_select:
        iters_group[item] = _imdb_iterator(item, extracted_files,
                                           tokenizer, ngrams, yield_cls=True)
    return iters_group


def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None,
                    removed_tokens=[], tokenizer=None,
                    data_select=('train', 'test')):

    if isinstance(data_select, str):
        data_select = [data_select]

    iters_group = _generate_data_iterators(dataset_name, root, ngrams,
                                           tokenizer, data_select)

    if vocab is None:
        if 'vocab' not in iters_group.keys():
            raise TypeError("Must pass a vocab if train is not selected.")
        logging.info('Building Vocab based on train data')
        vocab = build_vocab_from_iterator(iters_group['vocab'])
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))

    data = {}
    for item in data_select:
        data[item] = {}
        data[item]['data'] = []
        data[item]['labels'] = []
        logging.info('Creating {} data'.format(item))
        data_iter = _create_data_from_iterator(vocab, iters_group[item], removed_tokens)
        for cls, tokens in data_iter:
            # Each example is stored as a (label tensor, token-id tensor) pair.
            data[item]['data'].append((torch.tensor(cls),
                                       torch.tensor([token_id for token_id in tokens])))
            data[item]['labels'].append(cls)
        # Keep only the set of distinct labels seen in this split.
        data[item]['labels'] = set(data[item]['labels'])

    return tuple(TextClassificationDataset(vocab, data[item]['data'],
                                           data[item]['labels']) for item in data_select)


def IMDB(*args, **kwargs):
    """ Defines IMDB datasets.
        The labels include:
            - 0 : Negative
            - 1 : Positive

    Create sentiment analysis dataset: IMDB

    Separately returns the training and test datasets.

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string of text.
            Default: 1
        vocab: Vocabulary used for the dataset. If None, a new vocabulary
            is built from the train data set.
        removed_tokens: tokens removed from the output dataset (Default: [])
        tokenizer: the tokenizer used to preprocess raw text data.
            The default one is the basic_english tokenizer. The spacy
            tokenizer is supported as well. A custom tokenizer is a callable
            that takes a string and returns a list of tokens.
        data_select: a string or tuple for the returned datasets
            (Default: ('train', 'test'))
            By default, both datasets (train, test) are generated. Users
            could also choose either of them, for example ('train', 'test')
            or just the string 'train'. If 'train' is not in the tuple or
            string, a vocab object should be provided, which will be used to
            process the test data.

    Examples:
        >>> from torchtext.experimental.datasets import IMDB
        >>> from torchtext.data.utils import get_tokenizer
        >>> train, test = IMDB(ngrams=3)
        >>> tokenizer = get_tokenizer("spacy")
        >>> train, test = IMDB(tokenizer=tokenizer)
        >>> train, = IMDB(tokenizer=tokenizer, data_select='train')

    """
    return _setup_datasets(*(("IMDB",) + args), **kwargs)
DATASETS = {
    'IMDB': IMDB
}


LABELS = {
    'IMDB': {0: 'Negative',
             1: 'Positive'}
}
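
As a usage sketch (not part of the module), the datasets returned by IMDB can be fed to a standard DataLoader; the padding collate function below is hypothetical and assumes padding index 0:

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.experimental.datasets import IMDB

train, test = IMDB(ngrams=1)

def collate_batch(batch):
    # Each item is a (label tensor, token-id tensor) pair; pad the token
    # sequences in the batch to a common length (padding id 0 assumed).
    labels = torch.stack([label for label, _ in batch])
    texts = pad_sequence([text for _, text in batch],
                         batch_first=True, padding_value=0)
    return labels, texts

train_loader = DataLoader(train, batch_size=8, shuffle=True,
                          collate_fn=collate_batch)
for labels, texts in train_loader:
    pass  # labels: (8,), texts: (8, max_len_in_batch)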