Shortcuts

Source code for torchtext.utils

import six
import requests
import csv
from tqdm import tqdm
import os
import tarfile
import logging
import re
import sys
import zipfile


def reporthook(t):
    """Build a download progress callback that drives the progress bar ``t``.

    Adapted from https://github.com/tqdm/tqdm

    Arguments:
        t: progress-bar object; its ``total`` attribute is assigned and its
            ``update`` method is called by the returned callback.

    Returns:
        A callable ``inner(b=1, bsize=1, tsize=None)``.
    """
    # A dict (rather than a plain local) lets the closure rebind the counter.
    progress = {'blocks': 0}

    def inner(b=1, bsize=1, tsize=None):
        """Advance the bar by the blocks transferred since the previous call.

        b: int, optional
            Number of blocks just transferred [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        newly_transferred = b - progress['blocks']
        if tsize is not None:
            t.total = tsize
        t.update(newly_transferred * bsize)
        progress['blocks'] = b

    return inner
def download_from_url(url, path=None, root='.data', overwrite=False):
    """Download file, with logic (from tensor2tensor) for Google Drive.

    Returns the path to the downloaded file.

    Arguments:
        url: the url of the file
        path: explicitly set the destination of the file. When given, its
            directory part replaces ``root`` and its last component is used
            as the filename. Otherwise the filename is the last component
            of the URL (or, for Google Drive links, is detected from the
            Content-Disposition response header). (None)
        root: download folder used to store the file in (.data)
        overwrite: overwrite existing files (False)

    Raises:
        RuntimeError: if the download directory does not exist, or if the
            filename has to be detected from the response headers and no
            ``filename="..."`` entry is present.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> torchtext.utils.download_from_url(url)
        >>> '.data/validation.tar.gz'
    """
    def _process_response(r, root, filename):
        # Stream the body of response ``r`` into ``root/filename``.
        chunk_size = 16 * 1024
        try:
            total_size = int(r.headers.get('Content-length', 0))
            if filename is None:
                # A missing header and a header without a quoted filename
                # both end up as "no match" and raise the same error.
                disposition = r.headers.get('content-disposition', '')
                matches = re.findall("filename=\"(.+)\"", disposition)
                if not matches:
                    raise RuntimeError("Filename could not be autodetected")
                filename = matches[0]
            path = os.path.join(root, filename)
            if os.path.exists(path):
                logging.info('File %s already exists.' % path)
                if not overwrite:
                    return path
                logging.info('Overwriting file %s.' % path)
            logging.info('Downloading file {} to {}.'.format(filename, path))

            with open(path, "wb") as file:
                with tqdm(total=total_size, unit='B', unit_scale=1,
                          desc=os.path.basename(path)) as t:
                    for chunk in r.iter_content(chunk_size):
                        if chunk:
                            file.write(chunk)
                            t.update(len(chunk))
            logging.info('File {} downloaded.'.format(path))
            return path
        finally:
            # The request was made with stream=True, so the connection stays
            # open until the response is closed, including on early return.
            r.close()

    if path is None:
        _, filename = os.path.split(url)
    else:
        root, filename = os.path.split(path)

    # Checked before any network traffic so a bad destination fails fast.
    if not os.path.exists(root):
        raise RuntimeError(
            "Download directory {} does not exist. "
            "Did you create it?".format(root))

    if 'drive.google.com' not in url:
        response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
        return _process_response(response, root, filename)
    else:
        # google drive links get filename from google drive
        filename = None

    logging.info('Downloading from Google Drive; may take a few minutes')
    confirm_token = None
    session = requests.Session()
    response = session.get(url, stream=True)
    # A "download_warning*" cookie carries the token that must be echoed
    # back as a ``confirm`` query parameter to get the actual file.
    for k, v in response.cookies.items():
        if k.startswith("download_warning"):
            confirm_token = v

    if confirm_token:
        url = url + "&confirm=" + confirm_token
        response = session.get(url, stream=True)

    return _process_response(response, root, filename)
def unicode_csv_reader(unicode_csv_data, **kwargs):
    r"""Yield rows of unicode cells parsed from an iterable of unicode lines.

    The standard csv library does not handle unicode in Python 2, so on
    Python 2 the lines are encoded to UTF-8 for the parser and each cell is
    decoded again afterwards. Borrowed and slightly modified from the
    Python docs:
    https://docs.python.org/2/library/csv.html#csv-examples

    Arguments:
        unicode_csv_data: unicode csv data (see example below)
        **kwargs: forwarded unchanged to ``csv.reader``.

    Examples:
        >>> from torchtext.utils import unicode_csv_reader
        >>> import io
        >>> with io.open(data_path, encoding="utf8") as f:
        >>>     reader = unicode_csv_reader(f)

    """
    # Avoid "field larger than field limit" errors: start from sys.maxsize
    # and shrink by a factor of 10 for as long as OverflowError is raised.
    limit = sys.maxsize
    accepted = False
    while not accepted:
        try:
            csv.field_size_limit(limit)
        except OverflowError:
            limit = int(limit / 10)
        else:
            accepted = True
    csv.field_size_limit(limit)

    if not six.PY2:
        for row in csv.reader(unicode_csv_data, **kwargs):
            yield row
        return

    # Python 2: csv.py doesn't do Unicode, so feed it UTF-8 bytes and
    # decode every cell back to unicode.
    byte_rows = csv.reader(utf_8_encoder(unicode_csv_data), **kwargs)
    for byte_row in byte_rows:
        yield [cell.decode('utf-8') for cell in byte_row]
def utf_8_encoder(unicode_csv_data):
    """Lazily yield each item of ``unicode_csv_data`` encoded as UTF-8."""
    for text_line in unicode_csv_data:
        encoded_line = text_line.encode('utf-8')
        yield encoded_line
def _checked_member_path(to_path, member_name):
    """Join ``member_name`` onto ``to_path``, refusing names that escape it."""
    joined = os.path.join(to_path, member_name)
    base = os.path.realpath(to_path or os.curdir)
    target = os.path.realpath(joined)
    # os.path.join(base, '') appends exactly one trailing separator, so a
    # sibling such as "/data-evil" is not mistaken for a child of "/data".
    if target != base and not target.startswith(os.path.join(base, '')):
        raise RuntimeError(
            "Refusing to extract {} outside of {}.".format(member_name, to_path))
    return joined


def extract_archive(from_path, to_path=None, overwrite=False):
    """Extract archive.

    Arguments:
        from_path: the path of the archive.
        to_path: the root path of the extracted files (directory of from_path)
        overwrite: overwrite existing files (False)

    Returns:
        List of paths to extracted files even if not overwritten.

    Raises:
        RuntimeError: if a tar member's name (absolute, or containing
            ``..``) would resolve to a location outside ``to_path``.
            Raised before that member is written.
        NotImplementedError: if ``from_path`` does not end with
            ``.tar.gz``, ``.tgz`` or ``.zip``.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> from_path = './validation.tar.gz'
        >>> to_path = './'
        >>> torchtext.utils.download_from_url(url, from_path)
        >>> torchtext.utils.extract_archive(from_path, to_path)
        >>> ['.data/val.de', '.data/val.en']
    """

    if to_path is None:
        to_path = os.path.dirname(from_path)

    if from_path.endswith(('.tar.gz', '.tgz')):
        logging.info('Opening tar file {}.'.format(from_path))
        with tarfile.open(from_path, 'r') as tar:
            files = []
            for file_ in tar:
                # Archives are typically downloaded, so member names are
                # untrusted: reject any that would land outside to_path.
                # NOTE: only the member name is checked, not the target of
                # link members.
                file_path = _checked_member_path(to_path, file_.name)
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info('{} already extracted.'.format(file_path))
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            return files

    elif from_path.endswith('.zip'):
        assert zipfile.is_zipfile(from_path), from_path
        logging.info('Opening zip file {}.'.format(from_path))
        with zipfile.ZipFile(from_path, 'r') as zfile:
            files = []
            for file_ in zfile.namelist():
                file_path = os.path.join(to_path, file_)
                files.append(file_path)
                if os.path.exists(file_path):
                    logging.info('{} already extracted.'.format(file_path))
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
        # namelist() also contains directory entries; report files only.
        files = [f for f in files if os.path.isfile(f)]
        return files

    else:
        raise NotImplementedError(
            "We currently only support tar.gz, .tgz and zip achives.")
Read the Docs v: latest
Versions
latest
stable
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources