
Source code for torchtext.data.batch

import torch


class Batch(object):
    """Defines a batch of examples along with its Fields.

    Attributes:
        batch_size: Number of examples in the batch.
        dataset: A reference to the dataset object the examples come from
            (which itself contains the dataset's Field objects).
        train: Deprecated: this attribute is left for backwards compatibility,
            however it is UNUSED as of the merger with pytorch 0.4.
        input_fields: The names of the fields that are used as input for
            the model.
        target_fields: The names of the fields that are used as targets
            during model training.

    Also stores the Variable for each column in the batch as an attribute.
    """
    def __init__(self, data=None, dataset=None, device=None):
        """Create a Batch from a list of examples."""
        if data is not None:
            self.batch_size = len(data)
            self.dataset = dataset
            self.fields = dataset.fields.keys()  # copy field names
            self.input_fields = [k for k, v in dataset.fields.items()
                                 if v is not None and not v.is_target]
            self.target_fields = [k for k, v in dataset.fields.items()
                                  if v is not None and v.is_target]

            for (name, field) in dataset.fields.items():
                if field is not None:
                    batch = [getattr(x, name) for x in data]
                    setattr(self, name, field.process(batch, device=device))
    @classmethod
    def fromvars(cls, dataset, batch_size, train=None, **kwargs):
        """Create a Batch directly from a number of Variables."""
        batch = cls()
        batch.batch_size = batch_size
        batch.dataset = dataset
        batch.fields = dataset.fields.keys()
        for k, v in kwargs.items():
            setattr(batch, k, v)
        return batch
    def __repr__(self):
        return str(self)

    def __str__(self):
        if not self.__dict__:
            return 'Empty {} instance'.format(torch.typename(self))

        fields_to_index = filter(lambda field: field is not None, self.fields)
        var_strs = '\n'.join(['\t[.' + name + ']' + ":" + _short_str(getattr(self, name))
                              for name in fields_to_index if hasattr(self, name)])

        data_str = (' from {}'.format(self.dataset.name.upper())
                    if hasattr(self.dataset, 'name')
                    and isinstance(self.dataset.name, str) else '')

        strt = '[{} of size {}{}]\n{}'.format(torch.typename(self),
                                              self.batch_size, data_str, var_strs)
        return '\n' + strt

    def __len__(self):
        return self.batch_size

    def _get_field_values(self, fields):
        if len(fields) == 0:
            return None
        elif len(fields) == 1:
            return getattr(self, fields[0])
        else:
            return tuple(getattr(self, f) for f in fields)

    def __iter__(self):
        yield self._get_field_values(self.input_fields)
        yield self._get_field_values(self.target_fields)
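In normal use a Batch is not built by hand: an iterator constructs one per step, each processed column lands on the batch as an attribute, and __iter__ above lets the batch unpack into (inputs, targets). Below is a minimal usage sketch, assuming the legacy torchtext.data API (torchtext 0.8 and earlier, or torchtext.legacy); the field names and the file train.tsv are hypothetical.

from torchtext.data import BucketIterator, Field, TabularDataset

TEXT = Field(sequential=True, lower=True)        # input column
LABEL = Field(sequential=False, is_target=True)  # target column

# train.tsv is a hypothetical two-column file: text<TAB>label
dataset = TabularDataset(path='train.tsv', format='tsv',
                         fields=[('text', TEXT), ('label', LABEL)])
TEXT.build_vocab(dataset)
LABEL.build_vocab(dataset)

iterator = BucketIterator(dataset, batch_size=32,
                          sort_key=lambda ex: len(ex.text))
for batch in iterator:
    text, label = batch.text, batch.label  # one attribute per processed column
    x, y = batch  # __iter__ yields input fields first, then target fields

The module-level helper _short_str that __str__ uses to render each column follows.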
def _short_str(tensor):
    # unwrap variable to tensor
    if not torch.is_tensor(tensor):
        # (1) unpack variable
        if hasattr(tensor, 'data'):
            tensor = getattr(tensor, 'data')
        # (2) handle include_lengths
        elif isinstance(tensor, tuple):
            return str(tuple(_short_str(t) for t in tensor))
        # (3) fallback to default str
        else:
            return str(tensor)

    # copied from torch _tensor_str
    size_str = 'x'.join(str(size) for size in tensor.size())
    device_str = '' if not tensor.is_cuda else \
        ' (GPU {})'.format(tensor.get_device())
    strt = '[{} of size {}{}]'.format(torch.typename(tensor),
                                      size_str, device_str)
    return strt
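fromvars goes the other way: it skips Field.process and attaches ready-made tensors directly (BPTTIterator, for example, yields its batches this way). A hedged sketch follows; the SimpleNamespace stand-in for a dataset and the tensor shapes are illustrative assumptions, which works because fromvars only ever reads dataset.fields.keys().

import torch
from types import SimpleNamespace

from torchtext.data import Batch

# stand-in "dataset": fromvars only touches dataset.fields.keys()
dataset = SimpleNamespace(fields={'text': None, 'label': None})

batch = Batch.fromvars(dataset, batch_size=4,
                       text=torch.zeros(10, 4, dtype=torch.long),
                       label=torch.zeros(4, dtype=torch.long))

print(batch)  # __str__ renders each column via _short_str, roughly:
# [torchtext.data.batch.Batch of size 4]
#     [.text]:[torch.LongTensor of size 10x4]
#     [.label]:[torch.LongTensor of size 4]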