webdataset.dataset
Train PyTorch models directly from POSIX tar archives.
Code works locally or over HTTP connections.
View Source
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""Train PyTorch models directly from POSIX tar archives.

Code works locally or over HTTP connections.
"""

import os

from . import shardcache, tariterators, utils
from .composable import Composable, Processor
from .utils import lookup_sym, safe_eval
from .handlers import reraise_exception
from .pytorch import IterableDataset, DataLoader
from .shardlists import PytorchShardList, MultiShardSample
from .dsspecs import construct_dataset


# Defaults for the shard cache, overridable through environment variables.
default_cache_dir = os.path.expanduser(os.environ.get("WEBDATASET_CACHE", ""))
default_cache_name = lookup_sym(os.environ.get("WEBDATASET_CACHE_NAME", "shard_uuid"), ".shardcache".split())
default_cache_verbose = int(safe_eval(os.environ.get("WEBDATASET_CACHE_VERBOSE", "1")))
default_cache_size = int(float(safe_eval(os.environ.get("WEBDATASET_CACHE_SIZE", "1e15"))))


def WebDataset(
    urls,
    cache_dir=default_cache_dir,
    cache_size=default_cache_size,
    cache_name=default_cache_name,
    cache_verbose=default_cache_verbose,
    handler=reraise_exception,
    repeat=False,
):
    """Return a pipeline for WebDataset-style data files.

    This is a convenience function for constructing a partial pipeline
    that reads from a set of sharded tar files, extracts the individual
    files, and groups them together into samples (dictionaries).

    You can use all the methods from `Composable` (`then`, `compose`) and
    from `Shorthands` (`batched`, `unbatched`, `decode`, `shuffle`, etc.)
    on the result.

    The recommended way of specifying novel ways of splitting shards is
    via writing a new shardlist class.

    A string ending in ``.ds.yml`` is passed, together with all other
    arguments, to `construct_dataset`. A string ending in ``.shards.yml`` is
    loaded as a `MultiShardSample`. Any other string, or a list, is wrapped
    in a `PytorchShardList`.

    :param urls: the source URLs: a string, a list, a `Composable`, or an
        `IterableDataset`
    :param handler: an error handler
    :param cache_dir: when non-empty, caches shards in this directory
    :param cache_size: when set, specifies a maximum size for the shard cache
    :param cache_name: when set, specifies how shards should be named in the cache
    :param cache_verbose: when set, prints information about caching
    :param repeat: repeat infinitely if True
    :raises ValueError: if `urls` is a ``.yml``/``.yaml``/``.json`` spec other
        than ``.ds.yml`` or ``.shards.yml``, or is of an unsupported type
    """
    if isinstance(urls, str) and urls.endswith(".ds.yml"):
        # A full dataset spec builds its own pipeline from the same options.
        return construct_dataset(
            urls,
            cache_dir=cache_dir,
            cache_size=cache_size,
            cache_name=cache_name,
            cache_verbose=cache_verbose,
            handler=handler,
            repeat=repeat,
        )

    if isinstance(urls, str):
        if urls.endswith(".shards.yml"):
            urls = MultiShardSample(urls)
        elif os.path.splitext(urls)[1] in (".yml", ".yaml", ".json"):
            # splitext() keeps the leading dot, hence ".yml" rather than "yml".
            raise ValueError("bad shard spec (only '.shards.yml' supported right now)")
        result = PytorchShardList(urls)
    elif isinstance(urls, list):
        result = PytorchShardList(urls)
    elif isinstance(urls, (Composable, IterableDataset)):
        # Already a pipeline/dataset; use it directly as the shard source.
        result = urls
    else:
        raise ValueError(f"{type(urls)}: unknown shard list type")

    result = result.then(tariterators.url_opener, handler=handler)
    if cache_dir:
        # An empty (or None) cache_dir disables shard caching.
        result = result.then(
            shardcache.cache_shards,
            cache_dir=cache_dir,
            cache_size=cache_size,
            cache_name=cache_name,
            verbose=cache_verbose,
        )
    result = result.then(tariterators.tar_file_expander, handler=handler)
    result = result.then(tariterators.group_by_keys)
    if repeat:
        result = result.repeat()
    return result


def WebLoader(*args, **kw):
    """Return a small wrapper around torch.utils.data.DataLoader.

    This wrapper works identically to the original `DataLoader`, but adds
    all the convenience functions and filters for WebDataset.

    You can use all the methods from `Composable` (`then`, `compose`) and
    from `Shorthands` (`batched`, `unbatched`, `decode`, `shuffle`, etc.)
    on the result.

    :param args: forwarded to `DataLoader`
    :param kw: forwarded to `DataLoader`
    """
    return Processor(DataLoader(*args, **kw), utils.identity)
View Source
def WebDataset(
    urls,
    cache_dir=default_cache_dir,
    cache_size=default_cache_size,
    cache_name=default_cache_name,
    cache_verbose=default_cache_verbose,
    handler=reraise_exception,
    repeat=False,
):
    """Return a pipeline for WebDataset-style data files.

    This is a convenience function for constructing a partial pipeline
    that reads from a set of sharded tar files, extracts the individual
    files, and groups them together into samples (dictionaries).

    You can use all the methods from `Composable` (`then`, `compose`) and
    from `Shorthands` (`batched`, `unbatched`, `decode`, `shuffle`, etc.)
    on the result.

    The recommended way of specifying novel ways of splitting shards is
    via writing a new shardlist class.

    A string ending in ``.ds.yml`` is passed, together with all other
    arguments, to `construct_dataset`. A string ending in ``.shards.yml`` is
    loaded as a `MultiShardSample`. Any other string, or a list, is wrapped
    in a `PytorchShardList`.

    :param urls: the source URLs: a string, a list, a `Composable`, or an
        `IterableDataset`
    :param handler: an error handler
    :param cache_dir: when non-empty, caches shards in this directory
    :param cache_size: when set, specifies a maximum size for the shard cache
    :param cache_name: when set, specifies how shards should be named in the cache
    :param cache_verbose: when set, prints information about caching
    :param repeat: repeat infinitely if True
    :raises ValueError: if `urls` is a ``.yml``/``.yaml``/``.json`` spec other
        than ``.ds.yml`` or ``.shards.yml``, or is of an unsupported type
    """
    if isinstance(urls, str) and urls.endswith(".ds.yml"):
        # A full dataset spec builds its own pipeline from the same options.
        return construct_dataset(
            urls,
            cache_dir=cache_dir,
            cache_size=cache_size,
            cache_name=cache_name,
            cache_verbose=cache_verbose,
            handler=handler,
            repeat=repeat,
        )

    if isinstance(urls, str):
        if urls.endswith(".shards.yml"):
            urls = MultiShardSample(urls)
        elif os.path.splitext(urls)[1] in (".yml", ".yaml", ".json"):
            # splitext() keeps the leading dot, hence ".yml" rather than "yml".
            raise ValueError("bad shard spec (only '.shards.yml' supported right now)")
        result = PytorchShardList(urls)
    elif isinstance(urls, list):
        result = PytorchShardList(urls)
    elif isinstance(urls, (Composable, IterableDataset)):
        # Already a pipeline/dataset; use it directly as the shard source.
        result = urls
    else:
        raise ValueError(f"{type(urls)}: unknown shard list type")

    result = result.then(tariterators.url_opener, handler=handler)
    if cache_dir:
        # An empty (or None) cache_dir disables shard caching.
        result = result.then(
            shardcache.cache_shards,
            cache_dir=cache_dir,
            cache_size=cache_size,
            cache_name=cache_name,
            verbose=cache_verbose,
        )
    result = result.then(tariterators.tar_file_expander, handler=handler)
    result = result.then(tariterators.group_by_keys)
    if repeat:
        result = result.repeat()
    return result
Return a pipeline for WebDataset-style data files.
This is a convenience function for constructing a partial pipeline that reads from a set of sharded tar files, extracts the individual files, and groups them together into samples (dictionaries).
You can use all the methods from Composable (then, compose) and
from Shorthands (batched, unbatched, decode, shuffle, etc.)
on the result.
The recommended way of specifying novel ways of splitting shards is via writing a new shardlist class.
:param urls: the source URLs: a string, a list, a Composable, or an IterableDataset
:param handler: an error handler
:param cache_dir: when set, caches shards in this directory
:param cache_size: when set, specifies a maximum size for the shard cache
:param cache_name: when set, specifies how shards should be named in the cache
:param cache_verbose: when set, prints information about caching
:param repeat: repeat infinitely if True
View Source
def WebLoader(*args, **kw):
    """Build a `DataLoader` and wrap it so it exposes the WebDataset pipeline API.

    All positional and keyword arguments go straight to `DataLoader`. The
    loader is then wrapped in a `Processor` whose function is
    `utils.identity`. Per the original documentation, the result works
    identically to a plain `DataLoader`, but also offers all the convenience
    functions and filters for WebDataset: the methods from `Composable`
    (`then`, `compose`) and from `Shorthands` (`batched`, `unbatched`,
    `decode`, `shuffle`, etc.).

    :param args: forwarded to `DataLoader`
    :param kw: forwarded to `DataLoader`
    :returns: a `Processor` wrapping the newly constructed `DataLoader`
    """
    loader = DataLoader(*args, **kw)
    # The identity function means the Processor only adds the pipeline methods.
    wrapped = Processor(loader, utils.identity)
    return wrapped
Return a small wrapper around torch.utils.data.DataLoader.
This wrapper works identically to the original DataLoader, but adds
all the convenience functions and filters for WebDataset.
You can use all the methods from Composable (then, compose) and
from Shorthands (batched, unbatched, decode, shuffle, etc.)
on the result.
:param args: forwarded to DataLoader
:param kw: forwarded to DataLoader