Shortcuts

Source code for torchtext.datasets.text_classification

import logging
import torch
import io
from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.vocab import Vocab
from tqdm import tqdm

# Google Drive download URLs for the .tar.gz archive of each supported
# dataset, keyed by dataset name (same keys as DATASETS/LABELS below).
URLS = {
    'AG_NEWS':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUDNpeUdjb0wxRms',
    'SogouNews':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbUkVqNEszd0pHaFE',
    'DBpedia':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k',
    'YelpReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbNUpYQ2N3SGlFaDg',
    'YelpReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0',
    'YahooAnswers':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU',
    'AmazonReviewPolarity':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM',
    'AmazonReviewFull':
        'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbZVhsUnRWRDhETzA'
}


def _csv_iterator(data_path, ngrams, yield_cls=False):
    """Iterate over a labeled CSV file, yielding ngram iterators per row.

    Arguments:
        data_path: path to a CSV file whose first column is the label and
            remaining columns hold the text.
        ngrams: maximum ngram length passed to ``ngrams_iterator``.
        yield_cls: when True, yield ``(label, ngram_iterator)`` pairs with
            the label shifted to be zero-based; otherwise yield only the
            ngram iterator.
    """
    tokenize = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as f:
        for row in unicode_csv_reader(f):
            # Everything after the label column is treated as one text.
            token_list = tokenize(' '.join(row[1:]))
            grams = ngrams_iterator(token_list, ngrams)
            if yield_cls:
                # Raw labels are 1-based in the CSV files; shift to 0-based.
                yield int(row[0]) - 1, grams
            else:
                yield grams


def _create_data_from_iterator(vocab, iterator, include_unk):
    """Numericalize an iterator of (label, tokens) pairs into tensors.

    Arguments:
        vocab: Vocab object mapping string tokens to integer ids.
        iterator: yields ``(label, token_iterable)`` pairs
            (e.g. from ``_csv_iterator(..., yield_cls=True)``).
        include_unk: if False, token ids equal to the unknown-token id are
            dropped from each row.

    Returns:
        ``(data, labels)`` where ``data`` is a list of
        ``(label, token_id_tensor)`` pairs and ``labels`` is the set of
        distinct labels encountered.
    """
    data = []
    labels = []
    with tqdm(unit_scale=0, unit='lines') as t:
        for cls, tokens in iterator:
            if include_unk:
                tokens = torch.tensor([vocab[token] for token in tokens])
            else:
                # Bug fix: the original filter used
                # ``x is not Vocab.UNK``, comparing integer token ids
                # against the class attribute Vocab.UNK (the '<unk>'
                # string) by identity — always True, so unknown tokens
                # were never actually removed. Compare against the id of
                # the unknown token instead.
                unk_id = vocab[Vocab.UNK]
                token_ids = [vocab[token] for token in tokens]
                tokens = torch.tensor([tid for tid in token_ids
                                       if tid != unk_id])
            if len(tokens) == 0:
                logging.info('Row contains no tokens.')
            data.append((cls, tokens))
            labels.append(cls)
            t.update(1)
    return data, set(labels)


class TextClassificationDataset(torch.utils.data.Dataset):
    """Map-style dataset for text classification.

    Wraps pre-numericalized examples together with the vocabulary used to
    produce them. Currently, we only support the following datasets:

        - AG_NEWS
        - SogouNews
        - DBpedia
        - YelpReviewPolarity
        - YelpReviewFull
        - YahooAnswers
        - AmazonReviewPolarity
        - AmazonReviewFull
    """

    def __init__(self, vocab, data, labels):
        """Initiate text-classification dataset.

        Arguments:
            vocab: Vocabulary object used for dataset.
            data: a list of label/tokens tuples. tokens are a tensor after
                numericalizing the string tokens; label is an integer.
                [(label1, tokens1), (label2, tokens2), (label2, tokens3)]
            labels: a set of the labels, e.g. {label1, label2}.

        Examples:
            See the examples in examples/text_classification/
        """
        super(TextClassificationDataset, self).__init__()
        self._data = data
        self._labels = labels
        self._vocab = vocab

    def __getitem__(self, i):
        # Each item is a (label, token_tensor) pair.
        return self._data[i]

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        yield from self._data

    def get_labels(self):
        # Set of distinct labels present in this dataset.
        return self._labels

    def get_vocab(self):
        # Vocabulary used to numericalize the examples.
        return self._vocab
def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None,
                    include_unk=False):
    """Download, extract and numericalize one of the supported datasets.

    Arguments:
        dataset_name: key into URLS (e.g. 'AG_NEWS').
        root: directory where the dataset archive is saved.
        ngrams: maximum ngram length used during tokenization.
        vocab: an existing Vocab to reuse; when None a new vocabulary is
            built from the training CSV.
        include_unk: whether unknown tokens are kept in the data.

    Returns:
        A (train, test) pair of TextClassificationDataset objects sharing
        the same vocabulary.
    """
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    # Locate the train/test CSVs inside the extracted archive.
    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    elif not isinstance(vocab, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))

    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True),
        include_unk)

    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True),
        include_unk)

    # Symmetric difference: any label present in only one split is an error.
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")

    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))
def AG_NEWS(*args, **kwargs):
    """ Defines AG_NEWS datasets.
        The labels includes:

            - 1 : World
            - 2 : Sports
            - 3 : Business
            - 4 : Sci/Tech

    Create supervised learning dataset: AG_NEWS

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.AG_NEWS(ngrams=3)

    """
    return _setup_datasets("AG_NEWS", *args, **kwargs)
def SogouNews(*args, **kwargs):
    """ Defines SogouNews datasets.
        The labels includes:

            - 1 : Sports
            - 2 : Finance
            - 3 : Entertainment
            - 4 : Automobile
            - 5 : Technology

    Create supervised learning dataset: SogouNews

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.SogouNews(ngrams=3)

    """
    return _setup_datasets("SogouNews", *args, **kwargs)
def DBpedia(*args, **kwargs):
    """ Defines DBpedia datasets.
        The labels includes:

            - 1 : Company
            - 2 : EducationalInstitution
            - 3 : Artist
            - 4 : Athlete
            - 5 : OfficeHolder
            - 6 : MeanOfTransportation
            - 7 : Building
            - 8 : NaturalPlace
            - 9 : Village
            - 10 : Animal
            - 11 : Plant
            - 12 : Album
            - 13 : Film
            - 14 : WrittenWork

    Create supervised learning dataset: DBpedia

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.DBpedia(ngrams=3)

    """
    return _setup_datasets("DBpedia", *args, **kwargs)
def YelpReviewPolarity(*args, **kwargs):
    """ Defines YelpReviewPolarity datasets.
        The labels includes:

            - 1 : Negative polarity.
            - 2 : Positive polarity.

    Create supervised learning dataset: YelpReviewPolarity

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.YelpReviewPolarity(ngrams=3)

    """
    return _setup_datasets("YelpReviewPolarity", *args, **kwargs)
def YelpReviewFull(*args, **kwargs):
    """ Defines YelpReviewFull datasets.
        The labels includes:

            1 - 5 : rating classes (5 is highly recommended).

    Create supervised learning dataset: YelpReviewFull

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.YelpReviewFull(ngrams=3)

    """
    return _setup_datasets("YelpReviewFull", *args, **kwargs)
def YahooAnswers(*args, **kwargs):
    """ Defines YahooAnswers datasets.
        The labels includes:

            - 1 : Society & Culture
            - 2 : Science & Mathematics
            - 3 : Health
            - 4 : Education & Reference
            - 5 : Computers & Internet
            - 6 : Sports
            - 7 : Business & Finance
            - 8 : Entertainment & Music
            - 9 : Family & Relationships
            - 10 : Politics & Government

    Create supervised learning dataset: YahooAnswers

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.YahooAnswers(ngrams=3)

    """
    return _setup_datasets("YahooAnswers", *args, **kwargs)
def AmazonReviewPolarity(*args, **kwargs):
    """ Defines AmazonReviewPolarity datasets.
        The labels includes:

            - 1 : Negative polarity
            - 2 : Positive polarity

    Create supervised learning dataset: AmazonReviewPolarity

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the datasets are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.AmazonReviewPolarity(ngrams=3)

    """
    return _setup_datasets("AmazonReviewPolarity", *args, **kwargs)
def AmazonReviewFull(*args, **kwargs):
    """ Defines AmazonReviewFull datasets.
        The labels includes:

            1 - 5 : rating classes (5 is highly recommended)

    Create supervised learning dataset: AmazonReviewFull

    Separately returns the training and test dataset

    Arguments:
        root: Directory where the dataset are saved. Default: ".data"
        ngrams: a contiguous sequence of n items from a string text.
            Default: 1
        vocab: Vocabulary used for dataset. If None, it will generate a new
            vocabulary based on the train data set.
        include_unk: include unknown token in the data (Default: False)

    Examples:
        >>> train_dataset, test_dataset = torchtext.datasets.AmazonReviewFull(ngrams=3)

    """
    return _setup_datasets("AmazonReviewFull", *args, **kwargs)
# Registry mapping dataset names to their loader functions.
DATASETS = {
    'AG_NEWS': AG_NEWS,
    'SogouNews': SogouNews,
    'DBpedia': DBpedia,
    'YelpReviewPolarity': YelpReviewPolarity,
    'YelpReviewFull': YelpReviewFull,
    'YahooAnswers': YahooAnswers,
    'AmazonReviewPolarity': AmazonReviewPolarity,
    'AmazonReviewFull': AmazonReviewFull
}


# Human-readable label names per dataset, keyed by the raw (1-based) label
# values used in the CSV files.
LABELS = {
    'AG_NEWS': {1: 'World',
                2: 'Sports',
                3: 'Business',
                4: 'Sci/Tech'},
    'SogouNews': {1: 'Sports',
                  2: 'Finance',
                  3: 'Entertainment',
                  4: 'Automobile',
                  5: 'Technology'},
    'DBpedia': {1: 'Company',
                2: 'EducationalInstitution',
                3: 'Artist',
                4: 'Athlete',
                5: 'OfficeHolder',
                6: 'MeanOfTransportation',
                7: 'Building',
                8: 'NaturalPlace',
                9: 'Village',
                10: 'Animal',
                11: 'Plant',
                12: 'Album',
                13: 'Film',
                14: 'WrittenWork'},
    'YelpReviewPolarity': {1: 'Negative polarity',
                           2: 'Positive polarity'},
    'YelpReviewFull': {1: 'score 1',
                       2: 'score 2',
                       3: 'score 3',
                       4: 'score 4',
                       5: 'score 5'},
    'YahooAnswers': {1: 'Society & Culture',
                     2: 'Science & Mathematics',
                     3: 'Health',
                     4: 'Education & Reference',
                     5: 'Computers & Internet',
                     6: 'Sports',
                     7: 'Business & Finance',
                     8: 'Entertainment & Music',
                     9: 'Family & Relationships',
                     10: 'Politics & Government'},
    'AmazonReviewPolarity': {1: 'Negative polarity',
                             2: 'Positive polarity'},
    'AmazonReviewFull': {1: 'score 1',
                         2: 'score 2',
                         3: 'score 3',
                         4: 'score 4',
                         5: 'score 5'}
}
Read the Docs v: latest
Versions
latest
stable
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources