diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 690c3a19601..2e3728ef83a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -72,6 +72,10 @@ title: Semantic segmentation - local: object_detection title: Object detection + - local: video_load + title: Load video data + - local: video_dataset + title: Create a video dataset title: "Vision" - sections: - local: nlp_load diff --git a/docs/source/about_mapstyle_vs_iterable.mdx b/docs/source/about_mapstyle_vs_iterable.mdx index 1e9fa279e11..f794eea5714 100644 --- a/docs/source/about_mapstyle_vs_iterable.mdx +++ b/docs/source/about_mapstyle_vs_iterable.mdx @@ -139,12 +139,12 @@ But using a shuffle buffer is not enough to provide a satisfactory shuffling for ```python # Stream from the internet my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True) -my_iterable_dataset.n_shards # 39 +my_iterable_dataset.num_shards # 39 # Stream from local files data_files = {"train": [f"path/to/data_{i}.csv" for i in range(1024)]} my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True) -my_iterable_dataset.n_shards # 1024 +my_iterable_dataset.num_shards # 1024 # From a generator function def my_generator(n, sources): @@ -154,7 +154,7 @@ def my_generator(n, sources): gen_kwargs = {"n": 10, "sources": [f"path/to/data_{i}" for i in range(1024)]} my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs) -my_iterable_dataset.n_shards # 1024 +my_iterable_dataset.num_shards # 1024 ``` ## Speed differences @@ -242,5 +242,5 @@ my_iterable_dataset = my_dataset.to_iterable_dataset() If you want to shuffle your dataset or [use it with a PyTorch DataLoader](./use_with_pytorch#stream-data), we recommend generating a sharded [`IterableDataset`]: ```python my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024) -my_iterable_dataset.n_shards # 1024 +my_iterable_dataset.num_shards # 1024 ``` diff --git a/docs/source/how_to.md b/docs/source/how_to.md index 7e6cf8f719e..223a7c2c4c0 100644 --- a/docs/source/how_to.md +++ b/docs/source/how_to.md @@ -14,7 +14,7 @@ The guides are organized into six sections: - General usage: Functions for general dataset loading and processing. The functions shown in this section are applicable across all dataset modalities. - Audio: How to load, process, and share audio datasets. -- Vision: How to load, process, and share image datasets. +- Vision: How to load, process, and share image and video datasets. - Text: How to load, process, and share text datasets. - Tabular: How to load, process, and share tabular datasets. - Dataset repository: How to share and upload a dataset to the Hub. 
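For context on the `n_shards` → `num_shards` rename in the docs above, here is a minimal, self-contained sketch (illustrative data; PyTorch assumed installed) of the sharded `IterableDataset` workflow those snippets describe:

```python
from datasets import Dataset
from torch.utils.data import DataLoader

# Build a small map-style dataset, then convert it into a sharded IterableDataset.
my_dataset = Dataset.from_dict({"text": [f"example {i}" for i in range(4096)]})
my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)
print(my_iterable_dataset.num_shards)  # 1024

# Each DataLoader worker is assigned a distinct subset of the shards.
dataloader = DataLoader(my_iterable_dataset, batch_size=32, num_workers=4)
for batch in dataloader:
    pass
```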
diff --git a/docs/source/nlp_load.mdx b/docs/source/nlp_load.mdx index 5cfe5d31e99..dae074ae3fc 100644 --- a/docs/source/nlp_load.mdx +++ b/docs/source/nlp_load.mdx @@ -33,4 +33,14 @@ To load remote text files via HTTP, pass the URLs instead: ```py >>> dataset = load_dataset("text", data_files="https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt") -``` \ No newline at end of file +``` + +To load XML data you can use the "xml" loader, which is equivalent to "text" with sample_by="document": + +```py +>>> from datasets import load_dataset +>>> dataset = load_dataset("xml", data_files={"train": ["my_xml_1.xml", "my_xml_2.xml"], "test": "my_xml_file.xml"}) + +# Load from a directory +>>> dataset = load_dataset("xml", data_dir="path/to/xml/dataset") +``` diff --git a/docs/source/package_reference/loading_methods.mdx b/docs/source/package_reference/loading_methods.mdx index b17cbed8a3b..b54ca4fa1eb 100644 --- a/docs/source/package_reference/loading_methods.mdx +++ b/docs/source/package_reference/loading_methods.mdx @@ -49,6 +49,12 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t") [[autodoc]] datasets.packaged_modules.json.Json +### XML + +[[autodoc]] datasets.packaged_modules.xml.XmlConfig + +[[autodoc]] datasets.packaged_modules.xml.Xml + ### Parquet [[autodoc]] datasets.packaged_modules.parquet.ParquetConfig @@ -79,6 +85,12 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t") [[autodoc]] datasets.packaged_modules.audiofolder.AudioFolder +### Videos + +[[autodoc]] datasets.packaged_modules.videofolder.VideoFolderConfig + +[[autodoc]] datasets.packaged_modules.videofolder.VideoFolder + ### WebDataset [[autodoc]] datasets.packaged_modules.webdataset.WebDataset diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx index c08d555cd29..185bde10d72 100644 --- a/docs/source/package_reference/main_classes.mdx +++ b/docs/source/package_reference/main_classes.mdx @@ -171,6 +171,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth - batch - skip - take + - shard - load_state_dict - state_dict - info @@ -245,6 +246,10 @@ Dictionary with split names as keys ('train', 'test' for example), and `Iterable [[autodoc]] datasets.Image +### Video + +[[autodoc]] datasets.Video + ## Filesystems [[autodoc]] datasets.filesystems.is_remote_filesystem diff --git a/docs/source/process.mdx b/docs/source/process.mdx index 11fd4a1e9b0..38989613ef3 100644 --- a/docs/source/process.mdx +++ b/docs/source/process.mdx @@ -736,7 +736,7 @@ Want to save your dataset to a cloud storage provider? Read our [Cloud Storage]( | JSON | [`Dataset.to_json`] | | Parquet | [`Dataset.to_parquet`] | | SQL | [`Dataset.to_sql`] | -| In-memory Python object | [`Dataset.to_pandas`] or [`Dataset.to_dict`] | +| In-memory Python object | [`Dataset.to_pandas`], [`Dataset.to_polars`] or [`Dataset.to_dict`] | For example, export your dataset to a CSV file like this: diff --git a/docs/source/share.mdx b/docs/source/share.mdx index cb3d401fed5..0528d6dc917 100644 --- a/docs/source/share.mdx +++ b/docs/source/share.mdx @@ -9,7 +9,7 @@ Dataset repositories offer features such as: - Commit history and diffs - Metadata for discoverability - Dataset cards for documentation, licensing, limitations, etc. -- [Dataset Viewer](../hub/datasets-viewer) +- [Dataset Viewer](https://huggingface.co/docs/hub/datasets-viewer) This guide will show you how to share a dataset folder or repository that can be easily accessed by anyone. 
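The export table updated above now lists [`Dataset.to_polars`] alongside `to_pandas` and `to_dict`; a quick sketch of those in-memory exports (column names are illustrative, and `to_polars` requires the `polars` package):

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["foo", "bar"], "label": [0, 1]})

df_pandas = ds.to_pandas()  # pandas.DataFrame
df_polars = ds.to_polars()  # polars.DataFrame
records = ds.to_dict()      # dict of lists, e.g. {"text": [...], "label": [...]}
```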
@@ -68,13 +68,13 @@ Check your directory to ensure the only files you're uploading are: ## huggingface-cli upload -Use the `huggingface-cli upload` command to upload files to the Hub directly. Internally, it uses the same [`upload_file`] and [`upload_folder`] helpers described in the [Upload guide](../huggingface_hub/guides/upload). In the examples below, we will walk through the most common use cases. For a full list of available options, you can run: +Use the `huggingface-cli upload` command to upload files to the Hub directly. Internally, it uses the same [`upload_file`] and [`upload_folder`] helpers described in the [Upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload). In the examples below, we will walk through the most common use cases. For a full list of available options, you can run: ```bash >>> huggingface-cli upload --help ``` -For more general information about `huggingface-cli` you can check the [CLI guide](../huggingface_hub/guides/cli). +For more general information about `huggingface-cli` you can check the [CLI guide](https://huggingface.co/docs/huggingface_hub/guides/cli). ### Upload an entire folder @@ -214,6 +214,6 @@ Congratulations, your dataset has now been uploaded to the Hugging Face Hub wher dataset = load_dataset("Wauplin/my-cool-dataset") ``` -If your dataset is supported, it should also have a [Dataset Viewer](../hub/datasets-viewer) for everyone to explore the dataset content. +If your dataset is supported, it should also have a [Dataset Viewer](https://huggingface.co/docs/hub/datasets-viewer) for everyone to explore the dataset content. Finally, don't forget to enrich the dataset card to document your dataset and make it discoverable! Check out the [Create a dataset card](dataset_card) guide to learn more. diff --git a/docs/source/stream.mdx b/docs/source/stream.mdx index df5109b25aa..0be393ce4a8 100644 --- a/docs/source/stream.mdx +++ b/docs/source/stream.mdx @@ -136,6 +136,35 @@ You can split your dataset one of two ways: + +### Shard + +🤗 Datasets supports sharding to divide a very large dataset into a predefined number of chunks. Specify the `num_shards` parameter in [`~IterableDataset.shard`] to determine the number of shards to split the dataset into. You'll also need to provide the shard you want to return with the `index` parameter. + +For example, the [amazon_polarity](https://huggingface.co/datasets/amazon_polarity) dataset has 4 shards (in this case they are 4 Parquet files): + +```py +>>> from datasets import load_dataset +>>> dataset = load_dataset("amazon_polarity", split="train", streaming=True) +>>> print(dataset) +IterableDataset({ + features: ['label', 'title', 'content'], + num_shards: 4 +}) +``` + +After sharding the dataset into two chunks, the first one will only have 2 shards: + +```py +>>> dataset.shard(num_shards=2, index=0) +IterableDataset({ + features: ['label', 'title', 'content'], + num_shards: 2 +}) +``` + +If your dataset has `dataset.num_shards==1`, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead. + ## Interleave [`interleave_datasets`] can combine an [`IterableDataset`] with other datasets. The combined dataset returns alternating examples from each of the original datasets. diff --git a/docs/source/use_with_pytorch.mdx b/docs/source/use_with_pytorch.mdx index 7f78d8de05c..375d6facc3e 100644 --- a/docs/source/use_with_pytorch.mdx +++ b/docs/source/use_with_pytorch.mdx @@ -216,7 +216,7 @@ If the dataset is split in several shards (i.e. 
if the dataset consists of multi ```py >>> my_iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train") ->>> my_iterable_dataset.n_shards +>>> my_iterable_dataset.num_shards 39 >>> dataloader = DataLoader(my_iterable_dataset, batch_size=32, num_workers=4) ``` @@ -259,7 +259,7 @@ Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of t For iterable datasets: -If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`), +If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`), then the shards are evenly assigned across the nodes, which is the most optimized. Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples. diff --git a/docs/source/video_dataset.mdx b/docs/source/video_dataset.mdx new file mode 100644 index 00000000000..79cefbd294b --- /dev/null +++ b/docs/source/video_dataset.mdx @@ -0,0 +1,172 @@ +# Create a video dataset + +This guide will show you how to create a video dataset with `VideoFolder` and some metadata. This is a no-code solution for quickly creating a video dataset with several thousand videos. + + + +You can control access to your dataset by requiring users to share their contact information first. Check out the [Gated datasets](https://huggingface.co/docs/hub/datasets-gated) guide for more information about how to enable this feature on the Hub. + + + +## VideoFolder + +The `VideoFolder` is a dataset builder designed to quickly load a video dataset with several thousand videos without requiring you to write any code. + + + +💡 Take a look at the [Split pattern hierarchy](repository_structure#split-pattern-hierarchy) to learn more about how `VideoFolder` creates dataset splits based on your dataset repository structure. + + + +`VideoFolder` automatically infers the class labels of your dataset based on the directory name. Store your dataset in a directory structure like: + +``` +folder/train/dog/golden_retriever.mp4 +folder/train/dog/german_shepherd.mp4 +folder/train/dog/chihuahua.mp4 + +folder/train/cat/maine_coon.mp4 +folder/train/cat/bengal.mp4 +folder/train/cat/birman.mp4 +``` + +Then users can load your dataset by specifying `videofolder` in [`load_dataset`] and the directory in `data_dir`: + +```py +>>> from datasets import load_dataset + +>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder") +``` + +You can also use `videofolder` to load datasets involving multiple splits. To do so, your dataset directory should have the following structure: + +``` +folder/train/dog/golden_retriever.mp4 +folder/train/cat/maine_coon.mp4 +folder/test/dog/german_shepherd.mp4 +folder/test/cat/bengal.mp4 +``` + + + +If all video files are contained in a single directory or if they are not on the same level of directory structure, `label` column won't be added automatically. If you need it, set `drop_labels=False` explicitly. + + + + +If there is additional information you'd like to include about your dataset, like text captions or bounding boxes, add it as a `metadata.csv` file in your folder. This lets you quickly create datasets for different computer vision tasks like text captioning or object detection. You can also use a JSONL file `metadata.jsonl`. 
+
+```
+folder/train/metadata.csv
+folder/train/0001.mp4
+folder/train/0002.mp4
+folder/train/0003.mp4
+```
+
+You can also zip your videos:
+
+```
+folder/metadata.csv
+folder/train.zip
+folder/test.zip
+folder/valid.zip
+```
+
+Your `metadata.csv` file must have a `file_name` column which links video files with their metadata:
+
+```csv
+file_name,additional_feature
+0001.mp4,This is a first value of a text feature you added to your videos
+0002.mp4,This is a second value of a text feature you added to your videos
+0003.mp4,This is a third value of a text feature you added to your videos
+```
+
+or using `metadata.jsonl`:
+
+```jsonl
+{"file_name": "0001.mp4", "additional_feature": "This is a first value of a text feature you added to your videos"}
+{"file_name": "0002.mp4", "additional_feature": "This is a second value of a text feature you added to your videos"}
+{"file_name": "0003.mp4", "additional_feature": "This is a third value of a text feature you added to your videos"}
+```
+
+
+
+If metadata files are present, the inferred labels based on the directory name are dropped by default. To include those labels, set `drop_labels=False` in `load_dataset`.
+
+
+
+### Video captioning
+
+Video captioning datasets have text describing a video. An example `metadata.csv` may look like:
+
+```csv
+file_name,text
+0001.mp4,This is a golden retriever playing with a ball
+0002.mp4,A german shepherd
+0003.mp4,One chihuahua
+```
+
+Load the dataset with `VideoFolder`, and it will create a `text` column for the video captions:
+
+```py
+>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder", split="train")
+>>> dataset[0]["text"]
+"This is a golden retriever playing with a ball"
+```
+
+### Upload dataset to the Hub
+
+Once you've created a dataset, you can share it to the Hub using `huggingface_hub`, for example. Make sure you have the [huggingface_hub](https://huggingface.co/docs/huggingface_hub/index) library installed and you're logged in to your Hugging Face account (see the [Upload with Python tutorial](upload_dataset#upload-with-python) for more details).
+
+Upload your dataset with `huggingface_hub.HfApi.upload_folder`:
+
+```py
+from huggingface_hub import HfApi
+api = HfApi()
+
+api.upload_folder(
+    folder_path="/path/to/local/dataset",
+    repo_id="username/my-cool-dataset",
+    repo_type="dataset",
+)
+```
+
+## WebDataset
+
+The [WebDataset](https://github.com/webdataset/webdataset) format is based on TAR archives and is suitable for big video datasets.
+Indeed, you can group your videos in TAR archives (e.g. 1GB of videos per TAR archive) and have thousands of TAR archives:
+
+```
+folder/train/00000.tar
+folder/train/00001.tar
+folder/train/00002.tar
+...
+```
+
+In the archives, each example is made of files sharing the same prefix:
+
+```
+e39871fd9fd74f55.mp4
+e39871fd9fd74f55.json
+f18b91585c4d3f3e.mp4
+f18b91585c4d3f3e.json
+ede6e66b2fb59aab.mp4
+ede6e66b2fb59aab.json
+ed600d57fcee4f94.mp4
+ed600d57fcee4f94.json
+...
+```
+
+You can provide labels/captions/features for your videos using JSON or text files, for example.
+
+For more details on the WebDataset format and the Python library, please check the [WebDataset documentation](https://webdataset.github.io/webdataset).
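The TAR shards described above can be written with the Python standard library; here is a minimal sketch (paths and file names are hypothetical) that packs each video together with its JSON metadata under a shared prefix:

```python
import tarfile
from pathlib import Path

# Pack .mp4 files and their sidecar .json metadata into one WebDataset-style TAR shard.
videos = sorted(Path("videos").glob("*.mp4"))
with tarfile.open("folder/train/00000.tar", "w") as tar:
    for video_path in videos:
        json_path = video_path.with_suffix(".json")
        # Files sharing the same prefix (stem) become one example when loaded.
        tar.add(video_path, arcname=video_path.name)
        tar.add(json_path, arcname=json_path.name)
```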
+
+Load your WebDataset and it will create one column per file suffix (here "mp4" and "json"):
+
+```python
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", split="train")
+>>> dataset[0]["json"]
+{"bbox": [[302.0, 109.0, 73.0, 52.0]], "categories": [0]}
+```
diff --git a/docs/source/video_load.mdx b/docs/source/video_load.mdx
new file mode 100644
index 00000000000..782869d2eed
--- /dev/null
+++ b/docs/source/video_load.mdx
@@ -0,0 +1,132 @@
+# Load video data
+
+
+
+Video support is experimental and is subject to change.
+
+
+
+Video datasets have [`Video`] type columns, which contain `decord` objects.
+
+
+
+To work with video datasets, you need to have the `vision` dependency installed. Check out the [installation](./installation#vision) guide to learn how to install it.
+
+
+
+When you load a video dataset and call the video column, the videos are decoded as `decord` `VideoReader` objects:
+
+```py
+>>> from datasets import load_dataset, Video
+
+>>> dataset = load_dataset("path/to/video/folder", split="train")
+>>> dataset[0]["video"]
+
+```
+
+
+
+Index into a video dataset using the row index first and then the `video` column - `dataset[0]["video"]` - to avoid reading all the video objects in the dataset. Otherwise, this can be a slow and time-consuming process if you have a large dataset.
+
+
+
+For a guide on how to load any type of dataset, take a look at the general loading guide.
+
+## Read frames
+
+Access frames directly from a video using the `VideoReader`:
+
+```python
+>>> dataset[0]["video"][0].shape # first frame
+(240, 320, 3)
+```
+
+To get multiple frames at once, use `get_batch`. This is the efficient way to obtain a long list of frames:
+
+```python
+>>> frames = dataset[0]["video"].get_batch([1, 3, 5, 7, 9])
+>>> frames.shape
+(5, 240, 320, 3)
+```
+
+## Local files
+
+You can load a dataset from the video path. Use the [`~Dataset.cast_column`] function to accept a column of video file paths, and decode it into a `decord` video with the [`Video`] feature:
+```py
+>>> from datasets import Dataset, Video
+
+>>> dataset = Dataset.from_dict({"video": ["path/to/video_1", "path/to/video_2", ..., "path/to/video_n"]}).cast_column("video", Video())
+>>> dataset[0]["video"]
+
+```
+
+If you only want to load the underlying path to the video dataset without decoding the video object, set `decode=False` in the [`Video`] feature:
+
+```py
+>>> dataset = dataset.cast_column("video", Video(decode=False))
+>>> dataset[0]["video"]
+{'bytes': None,
+ 'path': 'path/to/video/folder/video0.mp4'}
+```
+
+## VideoFolder
+
+You can also load a dataset with a `VideoFolder` dataset builder which does not require writing a custom dataloader. This makes `VideoFolder` ideal for quickly creating and loading video datasets with several thousand videos for different vision tasks.
Your video dataset structure should look like this: + +``` +folder/train/dog/golden_retriever.mp4 +folder/train/dog/german_shepherd.mp4 +folder/train/dog/chihuahua.mp4 + +folder/train/cat/maine_coon.mp4 +folder/train/cat/bengal.mp4 +folder/train/cat/birman.mp4 +``` + +Load your dataset by specifying `videofolder` and the directory of your dataset in `data_dir`: + +```py +>>> from datasets import load_dataset + +>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder") +>>> dataset["train"][0] +{"video": , "label": 0} + +>>> dataset["train"][-1] +{"video": , "label": 1} +``` + +Load remote datasets from their URLs with the `data_files` parameter: + +```py +>>> dataset = load_dataset("videofolder", data_files="https://foo.bar/videos.zip", split="train") +``` + +Some datasets have a metadata file (`metadata.csv`/`metadata.jsonl`) associated with it, containing other information about the data like bounding boxes, text captions, and labels. The metadata is automatically loaded when you call [`load_dataset`] and specify `videofolder`. + +To ignore the information in the metadata file, set `drop_labels=False` in [`load_dataset`], and allow `VideoFolder` to automatically infer the label name from the directory name: + +```py +>>> from datasets import load_dataset + +>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder", drop_labels=False) +``` + + + +For more information about creating your own `VideoFolder` dataset, take a look at the [Create a video dataset](./video_dataset) guide. + + + +## WebDataset + +The [WebDataset](https://github.com/webdataset/webdataset) format is based on a folder of TAR archives and is suitable for big video datasets. +Because of their size, WebDatasets are generally loaded in streaming mode (using `streaming=True`). + +You can load a WebDataset like this: + +```python +>>> from datasets import load_dataset + +>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True) +``` diff --git a/setup.py b/setup.py index 7c1391c1b61..65ea27eb731 100644 --- a/setup.py +++ b/setup.py @@ -187,6 +187,7 @@ "transformers>=4.42.0", # Pins numpy < 2 "zstandard", "polars[timezone]>=0.20.0", + "decord==0.6.0", ] @@ -234,7 +235,7 @@ setup( name="datasets", - version="3.0.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="3.1.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="HuggingFace community-driven open-source library of datasets", long_description=open("README.md", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py index 6bb25aadabc..4691019efe1 100644 --- a/src/datasets/__init__.py +++ b/src/datasets/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "3.0.2" +__version__ = "3.1.0" from .arrow_dataset import Dataset from .arrow_reader import ReadInstruction diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index b289fba4106..57f3024e53b 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -77,7 +77,7 @@ from .arrow_writer import ArrowWriter, OptimizedTypedSequence from .data_files import sanitize_patterns from .download.streaming_download_manager import xgetsize -from .features import Audio, ClassLabel, Features, Image, Sequence, Value +from .features import Audio, ClassLabel, Features, Image, Sequence, Value, Video from .features.features import ( FeatureType, _align_features, @@ -303,7 +303,7 @@ def _get_output_signature( tf_dtype = tf.float32 np_dtype = np.float32 elif np_arrays[0].dtype.kind == "U": # Unicode strings - np_dtype = np.unicode_ + np_dtype = np.str_ tf_dtype = tf.string else: raise RuntimeError( @@ -1416,9 +1416,9 @@ def save_to_disk( """ Saves a dataset to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`. - For [`Image`] and [`Audio`] data: + For [`Image`], [`Audio`] and [`Video`] data: - All the Image() and Audio() data are stored in the arrow files. + All the Image(), Audio() and Video() data are stored in the arrow files. If you want to store paths or urls, please use the Value("string") type. Args: @@ -4630,32 +4630,31 @@ def shard( self, num_shards: int, index: int, - contiguous: bool = False, + contiguous: bool = True, keep_in_memory: bool = False, indices_cache_file_name: Optional[str] = None, writer_batch_size: Optional[int] = 1000, ) -> "Dataset": """Return the `index`-nth shard from dataset split into `num_shards` pieces. - This shards deterministically. `dset.shard(n, i)` will contain all elements of dset whose - index mod `n = i`. + This shards deterministically. `dataset.shard(n, i)` splits the dataset into contiguous chunks, + so it can be easily concatenated back together after processing. If `len(dataset) % n == l`, then the + first `l` dataset each have length `(len(dataset) // n) + 1`, and the remaining dataset have length `(len(dataset) // n)`. + `datasets.concatenate_datasets([dset.shard(n, i) for i in range(n)])` returns a dataset with the same order as the original. - `dset.shard(n, i, contiguous=True)` will instead split dset into contiguous chunks, - so it can be easily concatenated back together after processing. If `n % i == l`, then the - first `l` shards will have length `(n // i) + 1`, and the remaining shards will have length `(n // i)`. - `datasets.concatenate([dset.shard(n, i, contiguous=True) for i in range(n)])` will return - a dataset with the same order as the original. + Note: n should be less or equal to the number of elements in the dataset `len(dataset)`. + + On the other hand, `dataset.shard(n, i, contiguous=False)` contains all elements of the dataset whose index mod `n = i`. Be sure to shard before using any randomizing operator (such as `shuffle`). It is best if the shard operator is used early in the dataset pipeline. - Args: num_shards (`int`): How many shards to split the dataset into. index (`int`): Which shard to select and return. - contiguous: (`bool`, defaults to `False`): + contiguous: (`bool`, defaults to `True`): Whether to select contiguous blocks of indices for shards. keep_in_memory (`bool`, defaults to `False`): Keep the dataset in memory instead of writing it to a cache file. 
@@ -4663,7 +4662,8 @@ def shard( Provide the name of a path for the cache file. It is used to store the indices of each shard instead of the automatically generated cache file name. writer_batch_size (`int`, defaults to `1000`): - Number of rows per write operation for the cache file writer. + This only concerns the indices mapping. + Number of indices per write operation for the cache file writer. This value is a good trade-off between memory usage during the processing, and processing speed. Higher value makes the processing do fewer lookups, lower value consume less temporary memory while running `map`. @@ -4729,7 +4729,7 @@ def to_csv( **to_csv_kwargs (additional keyword arguments): - Parameters to pass to pandas's [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_json.html). + Parameters to pass to pandas's [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html). @@ -5065,7 +5065,7 @@ def _estimate_nbytes(self) -> int: def extra_nbytes_visitor(array, feature): nonlocal extra_nbytes - if isinstance(feature, (Audio, Image)): + if isinstance(feature, (Audio, Image, Video)): for x in array.to_pylist(): if x is not None and x["bytes"] is None and x["path"] is not None: size = xgetsize(x["path"]) @@ -5249,15 +5249,16 @@ def _push_parquet_shards_to_hub( shards = (self.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards)) if decodable_columns: + from .io.parquet import get_writer_batch_size - def shards_with_embedded_external_files(shards): + def shards_with_embedded_external_files(shards: Iterator[Dataset]) -> Iterator[Dataset]: for shard in shards: format = shard.format shard = shard.with_format("arrow") shard = shard.map( embed_table_storage, batched=True, - batch_size=1000, + batch_size=get_writer_batch_size(shard.features), keep_in_memory=True, ) shard = shard.with_format(**format) @@ -5310,7 +5311,7 @@ def push_to_hub( """Pushes the dataset to the hub as a Parquet dataset. The dataset is pushed using HTTP requests and does not need to have neither git or git-lfs installed. - The resulting Parquet files are self-contained by default. If your dataset contains [`Image`] or [`Audio`] + The resulting Parquet files are self-contained by default. If your dataset contains [`Image`], [`Audio`] or [`Video`] data, the Parquet files will store the bytes of your images or audio files. You can disable this by setting `embed_external_files` to `False`. @@ -5399,6 +5400,13 @@ def push_to_hub( >>> french_dataset = load_dataset("/", "fr") ``` """ + if "Video(" in str(self.features): + raise NotImplementedError( + "push_to_hub is not implemented for video datasets, instead you should upload the video files " + "using e.g. the huggingface_hub library and optionally upload a metadata.csv or metadata.jsonl " + "file containing other information like video captions, features or labels. More information " + "at https://huggingface.co/docs/datasets/main/en/video_load#videofolder" + ) if config_name == "data": raise ValueError("`config_name` cannot be 'data'. Please, choose another name for configuration.") diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 3b9993736e4..23fd8b94b87 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -24,10 +24,11 @@ from fsspec.core import url_to_fs from . 
import config -from .features import Features, Image, Value +from .features import Audio, Features, Image, Value, Video from .features.features import ( FeatureType, _ArrayXDExtensionType, + _visit, cast_to_python_objects, generate_from_arrow_type, get_nested_type, @@ -48,6 +49,45 @@ type_ = type # keep python's type function +def get_writer_batch_size(features: Optional[Features]) -> Optional[int]: + """ + Get the writer_batch_size that defines the maximum row group size in the parquet files. + The default in `datasets` is 1,000 but we lower it to 100 for image/audio datasets and 10 for videos. + This allows to optimize random access to parquet file, since accessing 1 row requires + to read its entire row group. + + This can be improved to get optimized size for querying/iterating + but at least it matches the dataset viewer expectations on HF. + + Args: + features (`datasets.Features` or `None`): + Dataset Features from `datasets`. + Returns: + writer_batch_size (`Optional[int]`): + Writer batch size to pass to a dataset builder. + If `None`, then it will use the `datasets` default. + """ + if not features: + return None + + batch_size = np.inf + + def set_batch_size(feature: FeatureType) -> None: + nonlocal batch_size + if isinstance(feature, Image): + batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS) + elif isinstance(feature, Audio): + batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS) + elif isinstance(feature, Video): + batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS) + elif isinstance(feature, Value) and feature.dtype == "binary": + batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS) + + _visit(features, set_batch_size) + + return None if batch_size is np.inf else batch_size + + class SchemaInferenceError(ValueError): pass @@ -340,7 +380,9 @@ def __init__( self.fingerprint = fingerprint self.disable_nullable = disable_nullable - self.writer_batch_size = writer_batch_size or config.DEFAULT_MAX_BATCH_SIZE + self.writer_batch_size = ( + writer_batch_size or get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE + ) self.update_features = update_features self.with_metadata = with_metadata self.unit = unit diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 7328b90cbca..c3eee41c6e0 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -930,7 +930,8 @@ def incomplete_dir(dirname): # Sync info self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values()) self.info.download_checksums = dl_manager.get_recorded_sizes_checksums() - self.info.size_in_bytes = self.info.dataset_size + self.info.download_size + if self.info.download_size is not None: + self.info.size_in_bytes = self.info.dataset_size + self.info.download_size # Save info self._save_info() diff --git a/src/datasets/config.py b/src/datasets/config.py index e2de170bcba..43801efcaef 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -67,6 +67,17 @@ except importlib.metadata.PackageNotFoundError: pass + +DUCKDB_VERSION = "N/A" +DUCKDB_AVAILABLE = importlib.util.find_spec("duckdb") is not None + +if DUCKDB_AVAILABLE: + try: + DUCKDB_VERSION = version.parse(importlib.metadata.version("duckdb")) + logger.info(f"Duckdb version {DUCKDB_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass + TF_VERSION = "N/A" TF_AVAILABLE = False @@ -129,6 +140,7 @@ IS_MP3_SUPPORTED = importlib.util.find_spec("soundfile") is not 
None and version.parse( importlib.import_module("soundfile").__libsndfile_version__ ) >= version.parse("1.1.0") +DECORD_AVAILABLE = importlib.util.find_spec("decord") is not None # Optional compression tools RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None @@ -192,6 +204,7 @@ PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS = 100 PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS = 100 PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS = 100 +PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS = 10 # Offline mode _offline = os.environ.get("HF_DATASETS_OFFLINE") diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index f92a1a8afda..40ca0cd7312 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1230,9 +1230,9 @@ def save_to_disk( """ Saves a dataset dict to a filesystem using `fsspec.spec.AbstractFileSystem`. - For [`Image`] and [`Audio`] data: + For [`Image`], [`Audio`] and [`Video`] data: - All the Image() and Audio() data are stored in the arrow files. + All the Image(), Audio() and Video() data are stored in the arrow files. If you want to store paths or urls, please use the Value("string") type. Args: diff --git a/src/datasets/distributed.py b/src/datasets/distributed.py index e036fabaf2c..4697948f342 100644 --- a/src/datasets/distributed.py +++ b/src/datasets/distributed.py @@ -18,7 +18,7 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D For iterable datasets: - If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`), + If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`), then the shards are evenly assigned across the nodes, which is the most optimized. Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples. 
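As a usage note for the `split_dataset_by_node` docstring above, here is a minimal sketch of assigning a streaming dataset to one node (illustrative `rank`/`world_size` values; in practice they come from the distributed launcher):

```python
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# deepmind/code_contests streams from 39 shards, per the docs earlier in this diff.
ds = load_dataset("deepmind/code_contests", split="train", streaming=True)

# With world_size=3, 39 % 3 == 0, so each node is assigned 13 whole shards.
ds_rank0 = split_dataset_by_node(ds, rank=0, world_size=3)
```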
diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py index 9a1d1e3a53c..800f6821443 100644 --- a/src/datasets/download/streaming_download_manager.py +++ b/src/datasets/download/streaming_download_manager.py @@ -64,6 +64,8 @@ def __init__( self._data_dir = data_dir self._base_path = base_path or os.path.abspath(".") self.download_config = download_config or DownloadConfig() + self.downloaded_size = None + self.record_checksums = False @property def manual_dir(self): @@ -208,3 +210,9 @@ def iter_files(self, urlpaths: Union[str, List[str]]) -> Iterable[str]: ``` """ return FilesIterable.from_urlpaths(urlpaths, download_config=self.download_config) + + def manage_extracted_files(self): + pass + + def get_recorded_sizes_checksums(self): + pass diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py index 35ebfb4ac0c..bf38042eb81 100644 --- a/src/datasets/features/__init__.py +++ b/src/datasets/features/__init__.py @@ -12,8 +12,10 @@ "Image", "Translation", "TranslationVariableLanguages", + "Video", ] from .audio import Audio from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value from .image import Image from .translation import Translation, TranslationVariableLanguages +from .video import Video diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 1d241e0b7b7..34622cd94d9 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -43,6 +43,7 @@ from .audio import Audio from .image import Image, encode_pil_image from .translation import Translation, TranslationVariableLanguages +from .video import Video logger = logging.get_logger(__name__) @@ -1202,6 +1203,7 @@ class LargeList: Array5D, Audio, Image, + Video, ] @@ -1346,7 +1348,7 @@ def encode_nested_example(schema, obj, level=0): return list(obj) # Object with special encoding: # ClassLabel will convert from string to int, TranslationVariableLanguages does some checks - elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD)): + elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)): return schema.encode_example(obj) if obj is not None else None # Other object should be directly convertible to a native Arrow type (like Translation and Translation) return obj @@ -1397,7 +1399,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni else: return decode_nested_example([schema.feature], obj) # Object with special decoding: - elif isinstance(schema, (Audio, Image)): + elif isinstance(schema, (Audio, Image, Video)): # we pass the token to read and decode files from private repositories in streaming mode if obj is not None and schema.decode: return schema.decode_example(obj, token_per_repo_id=token_per_repo_id) @@ -1417,6 +1419,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni Array5D.__name__: Array5D, Audio.__name__: Audio, Image.__name__: Image, + Video.__name__: Video, } diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py new file mode 100644 index 00000000000..2cde83930ac --- /dev/null +++ b/src/datasets/features/video.py @@ -0,0 +1,296 @@ +import os +from dataclasses import dataclass, field +from io import BytesIO +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union + +import numpy as np +import pyarrow as pa + +from .. 
import config +from ..download.download_config import DownloadConfig +from ..table import array_cast +from ..utils.file_utils import is_local_path, xopen +from ..utils.py_utils import string_to_dict + + +if TYPE_CHECKING: + from decord import VideoReader + + from .features import FeatureType + + +@dataclass +class Video: + """ + **Experimental.** Video [`Feature`] to read video data from a video file. + + Input: The Video feature accepts as input: + - A `str`: Absolute path to the video file (i.e. random access is allowed). + - A `dict` with the keys: + + - `path`: String with relative path of the video file in a dataset repository. + - `bytes`: Bytes of the video file. + + This is useful for archived files with sequential access. + + - A `decord.VideoReader`: decord video reader object. + + Args: + mode (`str`, *optional*): + The mode to convert the video to. If `None`, the native mode of the video is used. + decode (`bool`, defaults to `True`): + Whether to decode the video data. If `False`, + returns the underlying dictionary in the format `{"path": video_path, "bytes": video_bytes}`. + + Examples: + + ```py + >>> from datasets import Dataset, Video + >>> ds = Dataset.from_dict({"video":["path/to/Screen Recording.mov"]}).cast_column("video", Video()) + >>> ds.features["video"] + Video(decode=True, id=None) + >>> ds[0]["video"] + + >>> ds = ds.cast_column('video', Video(decode=False)) + {'bytes': None, + 'path': 'path/to/Screen Recording.mov'} + ``` + """ + + decode: bool = True + id: Optional[str] = None + # Automatically constructed + dtype: ClassVar[str] = "decord.VideoReader" + pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()}) + _type: str = field(default="Video", init=False, repr=False) + + def __post_init__(self): + if config.DECORD_AVAILABLE: + patch_decord() + + def __call__(self): + return self.pa_type + + def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader"]) -> dict: + """Encode example into a format for Arrow. + + Args: + value (`str`, `np.ndarray`, `VideoReader` or `dict`): + Data passed as input to Video feature. + + Returns: + `dict` with "path" and "bytes" fields + """ + if config.DECORD_AVAILABLE: + from decord import VideoReader + + else: + VideoReader = None + + if isinstance(value, list): + value = np.array(value) + + if isinstance(value, str): + return {"path": value, "bytes": None} + elif isinstance(value, bytes): + return {"path": None, "bytes": value} + elif isinstance(value, np.ndarray): + # convert the video array to bytes + return encode_np_array(value) + elif VideoReader and isinstance(value, VideoReader): + # convert the decord video reader to bytes + return encode_decord_video(value) + elif value.get("path") is not None and os.path.isfile(value["path"]): + # we set "bytes": None to not duplicate the data if they're already available locally + return {"bytes": None, "path": value.get("path")} + elif value.get("bytes") is not None or value.get("path") is not None: + # store the video bytes, and path is used to infer the video format using the file extension + return {"bytes": value.get("bytes"), "path": value.get("path")} + else: + raise ValueError( + f"A video sample should have one of 'path' or 'bytes' but they are missing or None in {value}." + ) + + def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader": + """Decode example video file into video data. 
+ + Args: + value (`str` or `dict`): + A string with the absolute video file path, a dictionary with + keys: + + - `path`: String with absolute or relative video file path. + - `bytes`: The bytes of the video file. + token_per_repo_id (`dict`, *optional*): + To access and decode + video files from private repositories on the Hub, you can pass + a dictionary repo_id (`str`) -> token (`bool` or `str`). + + Returns: + `decord.VideoReader` + """ + if not self.decode: + raise RuntimeError("Decoding is disabled for this feature. Please use Video(decode=True) instead.") + + if config.DECORD_AVAILABLE: + from decord import VideoReader + + else: + raise ImportError("To support decoding videos, please install 'decord'.") + + if token_per_repo_id is None: + token_per_repo_id = {} + + path, bytes_ = value["path"], value["bytes"] + if bytes_ is None: + if path is None: + raise ValueError(f"A video should have one of 'path' or 'bytes' but both are None in {value}.") + else: + if is_local_path(path): + video = VideoReader(path) + else: + source_url = path.split("::")[-1] + pattern = ( + config.HUB_DATASETS_URL + if source_url.startswith(config.HF_ENDPOINT) + else config.HUB_DATASETS_HFFS_URL + ) + try: + repo_id = string_to_dict(source_url, pattern)["repo_id"] + token = token_per_repo_id.get(repo_id) + except ValueError: + token = None + download_config = DownloadConfig(token=token) + with xopen(path, "rb", download_config=download_config) as f: + bytes_ = BytesIO(f.read()) + video = VideoReader(bytes_) + else: + video = VideoReader(BytesIO(bytes_)) + return video + + def flatten(self) -> Union["FeatureType", Dict[str, "FeatureType"]]: + """If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary.""" + from .features import Value + + return ( + self + if self.decode + else { + "bytes": Value("binary"), + "path": Value("string"), + } + ) + + def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArray]) -> pa.StructArray: + """Cast an Arrow array to the Video arrow storage type. + The Arrow types that can be converted to the Video pyarrow storage type are: + + - `pa.string()` - it must contain the "path" data + - `pa.binary()` - it must contain the video bytes + - `pa.struct({"bytes": pa.binary()})` + - `pa.struct({"path": pa.string()})` + - `pa.struct({"bytes": pa.binary(), "path": pa.string()})` - order doesn't matter + - `pa.list(*)` - it must contain the video array data + + Args: + storage (`Union[pa.StringArray, pa.StructArray, pa.ListArray]`): + PyArrow array to cast. + + Returns: + `pa.StructArray`: Array in the Video arrow storage type, that is + `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. 
+ """ + if pa.types.is_string(storage.type): + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_binary(storage.type): + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_struct(storage.type): + if storage.type.get_field_index("bytes") >= 0: + bytes_array = storage.field("bytes") + else: + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + if storage.type.get_field_index("path") >= 0: + path_array = storage.field("path") + else: + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_list(storage.type): + bytes_array = pa.array( + [encode_np_array(np.array(arr))["bytes"] if arr is not None else None for arr in storage.to_pylist()], + type=pa.binary(), + ) + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays( + [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null() + ) + return array_cast(storage, self.pa_type) + + +def video_to_bytes(video: "VideoReader") -> bytes: + """Convert a decord Video object to bytes using native compression if possible""" + raise NotImplementedError() + + +def encode_decord_video(video: "VideoReader") -> dict: + if hasattr(video, "_hf_encoded"): + return video._hf_encoded + else: + raise NotImplementedError( + "Encoding a decord video is not implemented. " + "Please call `datasets.features.video.patch_decord()` before loading videos to enable this." + ) + + +def encode_np_array(array: np.ndarray) -> dict: + raise NotImplementedError() + + +# Patching decord a little bit to: +# 1. store the encoded video data {"path": ..., "bytes": ...} in `video._hf_encoded`` +# 2. set the decord bridge to numpy/torch/tf/jax using `video._hf_bridge_out` (per video instance) instead of decord.bridge.bridge_out (global) +# This doesn't affect the normal usage of decord. + + +def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None: + from decord.bridge import bridge_out + + if hasattr(uri, "read"): + self._hf_encoded = {"bytes": uri.read(), "path": None} + uri.seek(0) + elif isinstance(uri, str): + self._hf_encoded = {"bytes": None, "path": uri} + self._hf_bridge_out = bridge_out + self._original_init(uri, *args, **kwargs) + + +def _patched_next(self: "VideoReader", *args, **kwargs): + return self._hf_bridge_out(self._original_next(*args, **kwargs)) + + +def _patched_get_batch(self: "VideoReader", *args, **kwargs): + return self._hf_bridge_out(self._original_get_batch(*args, **kwargs)) + + +def patch_decord(): + # We need to import torch first, otherwise later it can cause issues + # e.g. 
"RuntimeError: random_device could not be read" + # when running `torch.tensor(value).share_memory_()` + # Same for duckdb which crashes on import + if config.TORCH_AVAILABLE: + import torch # noqa + if config.DUCKDB_AVAILABLE: + import duckdb # noqa + import decord.video_reader + from decord import VideoReader + + if not hasattr(VideoReader, "_hf_patched"): + decord.video_reader.bridge_out = lambda x: x + VideoReader._original_init = VideoReader.__init__ + VideoReader.__init__ = _patched_init + VideoReader._original_next = VideoReader.next + VideoReader.next = _patched_next + VideoReader._original_get_batch = VideoReader.get_batch + VideoReader.get_batch = _patched_get_batch + VideoReader._hf_patched = True diff --git a/src/datasets/formatting/jax_formatter.py b/src/datasets/formatting/jax_formatter.py index 8035341c5cd..e247b7b5822 100644 --- a/src/datasets/formatting/jax_formatter.py +++ b/src/datasets/formatting/jax_formatter.py @@ -100,11 +100,23 @@ def _tensorize(self, value): default_dtype = {"dtype": jnp.int32} elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating): default_dtype = {"dtype": jnp.float32} - elif config.PIL_AVAILABLE and "PIL" in sys.modules: + + if config.PIL_AVAILABLE and "PIL" in sys.modules: import PIL.Image if isinstance(value, PIL.Image.Image): value = np.asarray(value) + if config.DECORD_AVAILABLE and "decord" in sys.modules: + # We need to import torch first, otherwise later it can cause issues + # e.g. "RuntimeError: random_device could not be read" + # when running `torch.tensor(value).share_memory_()` + if config.TORCH_AVAILABLE: + import torch # noqa + from decord import VideoReader + + if isinstance(value, VideoReader): + value._hf_bridge_out = lambda x: jnp.array(np.asarray(x)) + return value # using global variable since `jaxlib.xla_extension.Device` is not serializable neither # with `pickle` nor with `dill`, so we need to use a global variable instead diff --git a/src/datasets/formatting/np_formatter.py b/src/datasets/formatting/np_formatter.py index 95bcff2b517..032758bce21 100644 --- a/src/datasets/formatting/np_formatter.py +++ b/src/datasets/formatting/np_formatter.py @@ -57,11 +57,23 @@ def _tensorize(self, value): default_dtype = {"dtype": np.int64} elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating): default_dtype = {"dtype": np.float32} - elif config.PIL_AVAILABLE and "PIL" in sys.modules: + + if config.PIL_AVAILABLE and "PIL" in sys.modules: import PIL.Image if isinstance(value, PIL.Image.Image): return np.asarray(value, **self.np_array_kwargs) + if config.DECORD_AVAILABLE and "decord" in sys.modules: + # We need to import torch first, otherwise later it can cause issues + # e.g. 
"RuntimeError: random_device could not be read" + # when running `torch.tensor(value).share_memory_()` + if config.TORCH_AVAILABLE: + import torch # noqa + from decord import VideoReader + + if isinstance(value, VideoReader): + value._hf_bridge_out = np.asarray + return value return np.asarray(value, **{**default_dtype, **self.np_array_kwargs}) diff --git a/src/datasets/formatting/tf_formatter.py b/src/datasets/formatting/tf_formatter.py index adb15cda381..9f0c06ec82a 100644 --- a/src/datasets/formatting/tf_formatter.py +++ b/src/datasets/formatting/tf_formatter.py @@ -64,11 +64,24 @@ def _tensorize(self, value): default_dtype = {"dtype": tf.int64} elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating): default_dtype = {"dtype": tf.float32} - elif config.PIL_AVAILABLE and "PIL" in sys.modules: + + if config.PIL_AVAILABLE and "PIL" in sys.modules: import PIL.Image if isinstance(value, PIL.Image.Image): value = np.asarray(value) + if config.DECORD_AVAILABLE and "decord" in sys.modules: + # We need to import torch first, otherwise later it can cause issues + # e.g. "RuntimeError: random_device could not be read" + # when running `torch.tensor(value).share_memory_()` + if config.TORCH_AVAILABLE: + import torch # noqa + from decord import VideoReader + from decord.bridge import to_tensorflow + + if isinstance(value, VideoReader): + value._hf_bridge_out = to_tensorflow + return value return tf.convert_to_tensor(value, **{**default_dtype, **self.tf_tensor_kwargs}) diff --git a/src/datasets/formatting/torch_formatter.py b/src/datasets/formatting/torch_formatter.py index 8efe759a144..051badb0ac4 100644 --- a/src/datasets/formatting/torch_formatter.py +++ b/src/datasets/formatting/torch_formatter.py @@ -66,7 +66,8 @@ def _tensorize(self, value): elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating): default_dtype = {"dtype": torch.float32} - elif config.PIL_AVAILABLE and "PIL" in sys.modules: + + if config.PIL_AVAILABLE and "PIL" in sys.modules: import PIL.Image if isinstance(value, PIL.Image.Image): @@ -75,6 +76,14 @@ def _tensorize(self, value): value = value[:, :, np.newaxis] value = value.transpose((2, 0, 1)) + if config.DECORD_AVAILABLE and "decord" in sys.modules: + from decord import VideoReader + from decord.bridge import to_torch + + if isinstance(value, VideoReader): + value._hf_bridge_out = to_torch + return value + return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs}) def _recursive_tensorize(self, data_struct): diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py index 51434106fb1..d34f5110204 100644 --- a/src/datasets/io/parquet.py +++ b/src/datasets/io/parquet.py @@ -2,11 +2,10 @@ from typing import BinaryIO, Optional, Union import fsspec -import numpy as np import pyarrow.parquet as pq -from .. import Audio, Dataset, Features, Image, NamedSplit, Value, config -from ..features.features import FeatureType, _visit +from .. import Dataset, Features, NamedSplit, config +from ..arrow_writer import get_writer_batch_size from ..formatting import query_table from ..packaged_modules import _PACKAGED_DATASETS_MODULES from ..packaged_modules.parquet.parquet import Parquet @@ -15,41 +14,6 @@ from .abc import AbstractDatasetReader -def get_writer_batch_size(features: Features) -> Optional[int]: - """ - Get the writer_batch_size that defines the maximum row group size in the parquet files. - The default in `datasets` is 1,000 but we lower it to 100 for image datasets. 
- This allows to optimize random access to parquet file, since accessing 1 row requires - to read its entire row group. - - This can be improved to get optimized size for querying/iterating - but at least it matches the dataset viewer expectations on HF. - - Args: - ds_config_info (`datasets.info.DatasetInfo`): - Dataset info from `datasets`. - Returns: - writer_batch_size (`Optional[int]`): - Writer batch size to pass to a dataset builder. - If `None`, then it will use the `datasets` default. - """ - - batch_size = np.inf - - def set_batch_size(feature: FeatureType) -> None: - nonlocal batch_size - if isinstance(feature, Image): - batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS) - elif isinstance(feature, Audio): - batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS) - elif isinstance(feature, Value) and feature.dtype == "binary": - batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS) - - _visit(features, set_batch_size) - - return None if batch_size is np.inf else batch_size - - class ParquetDatasetReader(AbstractDatasetReader): def __init__( self, diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index d1b54131b61..e38244383e2 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -141,16 +141,23 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamples """ raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet") - def shard_data_sources(self, worker_id: int, num_workers: int) -> "_BaseExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_BaseExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet") - def split_shard_indices_by_worker(self, worker_id: int, num_workers: int) -> List[int]: - return list(range(worker_id, self.n_shards, num_workers)) + def split_shard_indices_by_worker(self, num_shards: int, index: int, contiguous=True) -> List[int]: + if contiguous: + div = self.num_shards // num_shards + mod = self.num_shards % num_shards + start = div * index + min(index, mod) + end = start + div + (1 if index < mod else 0) + return list(range(start, end)) + else: + return list(range(index, self.num_shards, num_shards)) @property - def n_shards(self) -> int: - raise NotImplementedError(f"{type(self)} doesn't implement n_shards yet") + def num_shards(self) -> int: + raise NotImplementedError(f"{type(self)} doesn't implement num_shards yet") def _init_state_dict(self) -> dict: raise NotImplementedError(f"{type(self)} doesn't implement _init_state_dict yet") @@ -187,7 +194,7 @@ def _init_state_dict(self) -> dict: def __iter__(self): shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards), shard_idx_start, None): + for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 for key_example in islice(self.generate_examples_fn(**gen_kwags), shard_example_idx_start, None): if self._state_dict: @@ -200,15 +207,15 @@ def __iter__(self): def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable": return 
ShuffledDataSourcesExamplesIterable(self.generate_examples_fn, self.kwargs, generator) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "ExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": """Keep only the requested shard.""" - gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards) - shard_indices = self.split_shard_indices_by_worker(worker_id, num_workers) + gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) + shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) return ExamplesIterable(self.generate_examples_fn, requested_gen_kwargs) @property - def n_shards(self) -> int: + def num_shards(self) -> int: return _number_of_shards_in_gen_kwargs(self.kwargs) @@ -229,7 +236,7 @@ def __iter__(self): kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 for gen_kwags in islice( - _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards), shard_idx_start, None + _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None ): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 for key_example in islice(self.generate_examples_fn(**gen_kwags), shard_example_idx_start, None): @@ -240,12 +247,12 @@ def __iter__(self): self._state_dict["shard_idx"] += 1 self._state_dict["shard_example_idx"] = 0 - def shard_data_sources(self, worker_id: int, num_workers: int) -> "ExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": """Keep only the requested shard.""" rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) return ExamplesIterable(self.generate_examples_fn, kwargs_with_shuffled_shards).shard_data_sources( - worker_id, num_workers + num_shards, index, contiguous=contiguous ) @@ -266,7 +273,7 @@ def _init_state_dict(self) -> dict: def __iter__(self): formatter = PythonFormatter() shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards), shard_idx_start, None): + for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 shard_example_idx = 0 for key, pa_table in self.generate_tables_fn(**gen_kwags): @@ -287,7 +294,7 @@ def __iter__(self): def _iter_arrow(self): shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 - for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards), shard_idx_start, None): + for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 shard_example_idx = 0 for key, pa_table in self.generate_tables_fn(**gen_kwags): @@ -304,15 +311,15 @@ def _iter_arrow(self): def shuffle_data_sources(self, generator: np.random.Generator) -> "ArrowExamplesIterable": return ShuffledDataSourcesArrowExamplesIterable(self.generate_tables_fn, self.kwargs, generator) - def shard_data_sources(self, worker_id: int, num_workers: int) -> 
"ArrowExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": """Keep only the requested shard.""" - gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards) - shard_indices = self.split_shard_indices_by_worker(worker_id, num_workers) + gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) + shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) return ArrowExamplesIterable(self.generate_tables_fn, requested_gen_kwargs) @property - def n_shards(self) -> int: + def num_shards(self) -> int: return _number_of_shards_in_gen_kwargs(self.kwargs) @@ -337,7 +344,7 @@ def __iter__(self): formatter = PythonFormatter() shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 for gen_kwags in islice( - _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards), shard_idx_start, None + _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None ): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 shard_example_idx = 0 @@ -362,7 +369,7 @@ def _iter_arrow(self): kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 for gen_kwags in islice( - _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards), shard_idx_start, None + _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None ): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 shard_example_idx = 0 @@ -377,12 +384,12 @@ def _iter_arrow(self): self._state_dict["shard_idx"] += 1 self._state_dict["shard_example_idx"] = 0 - def shard_data_sources(self, worker_id: int, num_workers: int) -> "ArrowExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": """Keep only the requested shard.""" rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) return ArrowExamplesIterable(self.generate_tables_fn, kwargs_with_shuffled_shards).shard_data_sources( - worker_id, num_workers + num_shards, index, contiguous=contiguous ) @@ -505,14 +512,16 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RebatchedArro self.ex_iterable.shuffle_data_sources(generator), self.batch_size, self.drop_last_batch ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "RebatchedArrowExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RebatchedArrowExamplesIterable": return RebatchedArrowExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), self.batch_size, self.drop_last_batch + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), + self.batch_size, + self.drop_last_batch, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class SelectColumnsIterable(_BaseExamplesIterable): @@ -546,12 +555,14 @@ def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: def shuffle_data_sources(self, generator: np.random.Generator) -> "SelectColumnsIterable": return SelectColumnsIterable(self.ex_iterable.shuffle_data_sources(generator), 
self.column_names) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "SelectColumnsIterable": - return SelectColumnsIterable(self.ex_iterable.shard_data_sources(worker_id, num_workers), self.column_names) + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SelectColumnsIterable": + return SelectColumnsIterable( + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.column_names + ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class StepExamplesIterable(_BaseExamplesIterable): @@ -584,14 +595,16 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "StepExamplesI self.ex_iterable.shuffle_data_sources(generator), step=self.step, offset=self.offset ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "StepExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "StepExamplesIterable": return StepExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), step=self.step, offset=self.offset + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), + step=self.step, + offset=self.offset, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable): @@ -679,13 +692,15 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "CyclingMultiS return CyclingMultiSourcesExamplesIterable(ex_iterables, self.stopping_strategy) @property - def n_shards(self) -> int: - return min(ex_iterable.n_shards for ex_iterable in self.ex_iterables) + def num_shards(self) -> int: + return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "CyclingMultiSourcesExamplesIterable": + def shard_data_sources( + self, num_shards: int, index: int, contiguous=True + ) -> "CyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return CyclingMultiSourcesExamplesIterable( - [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables], + [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables], stopping_strategy=self.stopping_strategy, ) @@ -748,15 +763,15 @@ def shuffle_data_sources( return VerticallyConcatenatedMultiSourcesExamplesIterable(ex_iterables) @property - def n_shards(self) -> int: - return min(ex_iterable.n_shards for ex_iterable in self.ex_iterables) + def num_shards(self) -> int: + return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables) def shard_data_sources( - self, worker_id: int, num_workers: int + self, num_shards: int, index: int, contiguous=True ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return VerticallyConcatenatedMultiSourcesExamplesIterable( - [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables] + [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] ) @@ -829,15 +844,15 @@ def shuffle_data_sources( return self @property - def n_shards(self) -> int: + def num_shards(self) -> int: return 1 def shard_data_sources( - 
self, worker_id: int, num_workers: int + self, num_shards: int, index: int, contiguous=True ) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return HorizontallyConcatenatedMultiSourcesExamplesIterable( - [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables] + [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables] ) @@ -907,10 +922,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RandomlyCycli stopping_strategy=self.stopping_strategy, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "RandomlyCyclingMultiSourcesExamplesIterable": + def shard_data_sources( + self, num_shards: int, index: int, contiguous=True + ) -> "RandomlyCyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return RandomlyCyclingMultiSourcesExamplesIterable( - [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables], + [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables], self.generator, self.probabilities, self.stopping_strategy, @@ -1161,10 +1178,10 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample formatting=self.formatting, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "MappedExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MappedExamplesIterable": """Keep only the requested shard.""" return MappedExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), function=self.function, with_indices=self.with_indices, input_columns=self.input_columns, @@ -1177,8 +1194,8 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "MappedExample ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class FilteredExamplesIterable(_BaseExamplesIterable): @@ -1381,10 +1398,10 @@ def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable batch_size=self.batch_size, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "FilteredExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "FilteredExamplesIterable": """Keep only the requested shard.""" return FilteredExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), function=self.function, with_indices=self.with_indices, input_columns=self.input_columns, @@ -1393,8 +1410,8 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "FilteredExamp ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class BufferShuffledExamplesIterable(_BaseExamplesIterable): @@ -1451,17 +1468,17 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffle self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "BufferShuffledExamplesIterable": + def shard_data_sources(self, 
num_shards: int, index: int, contiguous=True) -> "BufferShuffledExamplesIterable": """Keep only the requested shard.""" return BufferShuffledExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), buffer_size=self.buffer_size, generator=self.generator, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class SkipExamplesIterable(_BaseExamplesIterable): @@ -1514,12 +1531,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "SkipExamplesI split_when_sharding=self.split_when_sharding, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "SkipExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SkipExamplesIterable": """Keep only the requested shard.""" if self.split_when_sharding: return SkipExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), - n=self.split_number(self.n, num_workers)[worker_id], + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), + n=self.split_number(self.n, num_shards)[index], block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, split_when_sharding=self.split_when_sharding, ) @@ -1527,8 +1544,8 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "SkipExamplesI return self @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards class TakeExamplesIterable(_BaseExamplesIterable): @@ -1582,26 +1599,26 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "TakeExamplesI split_when_sharding=self.split_when_sharding, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "TakeExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "TakeExamplesIterable": """Keep only the requested shard.""" if self.split_when_sharding: return TakeExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), - n=self.split_number(self.n, num_workers)[worker_id], + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), + n=self.split_number(self.n, num_shards)[index], block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, split_when_sharding=self.split_when_sharding, ) else: return TakeExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), n=self.n, block_sources_order_when_shuffling=self.block_sources_order_when_shuffling, split_when_sharding=self.split_when_sharding, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards def _apply_feature_types_on_example( @@ -1690,17 +1707,17 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "TypedExamples token_per_repo_id=self.token_per_repo_id, ) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "TypedExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "TypedExamplesIterable": """Keep only the requested shard.""" return TypedExamplesIterable( - self.ex_iterable.shard_data_sources(worker_id, num_workers), + self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), 
features=self.features, token_per_repo_id=self.token_per_repo_id, ) @property - def n_shards(self) -> int: - return self.ex_iterable.n_shards + def num_shards(self) -> int: + return self.ex_iterable.num_shards @dataclass @@ -1885,7 +1902,7 @@ def load_state_dict(self, state_dict: dict) -> None: self._starting_state_dict = state_dict def __repr__(self): - return f"IterableDataset({{\n features: {list(self._info.features.keys()) if self._info.features is not None else 'Unknown'},\n n_shards: {self.n_shards}\n}})" + return f"IterableDataset({{\n features: {list(self._info.features.keys()) if self._info.features is not None else 'Unknown'},\n num_shards: {self.num_shards}\n}})" def __getstate__(self): return self.__dict__ @@ -1916,10 +1933,14 @@ def _effective_generator(self): raise ValueError("This dataset is not shuffled") @property - def n_shards(self) -> int: - if self._distributed and self._ex_iterable.n_shards % self._distributed.world_size == 0: - return self._ex_iterable.n_shards // self._distributed.world_size - return self._ex_iterable.n_shards + def num_shards(self) -> int: + if self._distributed and self._ex_iterable.num_shards % self._distributed.world_size == 0: + return self._ex_iterable.num_shards // self._distributed.world_size + return self._ex_iterable.num_shards + + @property + def n_shards(self) -> int: # backward compatibility + return self.num_shards def _iter_pytorch(self): ex_iterable = self._prepare_ex_iterable_for_iteration() @@ -1930,24 +1951,28 @@ def _iter_pytorch(self): import torch.utils.data worker_info = torch.utils.data.get_worker_info() - if self._is_main_process() and ex_iterable.n_shards < worker_info.num_workers: + if self._is_main_process() and ex_iterable.num_shards < worker_info.num_workers: logger.warning( - f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={ex_iterable.n_shards}). " - f"Stopping {worker_info.num_workers - ex_iterable.n_shards} dataloader workers." + f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.num_shards={ex_iterable.num_shards}). " + f"Stopping {worker_info.num_workers - ex_iterable.num_shards} dataloader workers." ) logger.info( f"To parallelize data loading, we give each process some shards (or data sources) to process. " - f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={ex_iterable.n_shards}. " - f"To enable more parallelism, please split the dataset in more files than {ex_iterable.n_shards}." + f"Therefore it's unnecessary to have a number of workers greater than dataset.num_shards={ex_iterable.num_shards}. " + f"To enable more parallelism, please split the dataset in more files than {ex_iterable.num_shards}." ) # split workload _log_prefix = f"node#{self._distributed.rank} " if self._distributed else "" - shards_indices = ex_iterable.split_shard_indices_by_worker(worker_info.id, worker_info.num_workers) + shards_indices = ex_iterable.split_shard_indices_by_worker( + num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False + ) if shards_indices: logger.debug( - f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.n_shards} shards." + f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.num_shards} shards." 
+                )
+                ex_iterable = ex_iterable.shard_data_sources(
+                    num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
                 )
-                ex_iterable = ex_iterable.shard_data_sources(worker_id=worker_info.id, num_workers=worker_info.num_workers)
                 self._state_dict = ex_iterable._init_state_dict()
                 if self._starting_state_dict:
                     ex_iterable.load_state_dict(self._starting_state_dict)
@@ -1978,11 +2003,11 @@ def _iter_pytorch(self):
                         )
                     yield format_dict(example) if format_dict else example
                 logger.debug(
-                    f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.n_shards} shards."
+                    f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.num_shards} shards."
                 )
             else:
                 logger.debug(
-                    f"{_log_prefix}dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({ex_iterable.n_shards}<{worker_info.num_workers})."
+                    f"{_log_prefix}dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({ex_iterable.num_shards}<{worker_info.num_workers})."
                 )
 
     def _is_main_process(self):
@@ -2012,14 +2037,14 @@ def _prepare_ex_iterable_for_iteration(
         if self._distributed:
             rank = self._distributed.rank
             world_size = self._distributed.world_size
-            if ex_iterable.n_shards % world_size == 0:
+            if ex_iterable.num_shards % world_size == 0:
                 if self._is_main_process():
-                    n_shards_per_node = ex_iterable.n_shards // world_size
-                    plural = "s" if n_shards_per_node > 1 else ""
+                    num_shards_per_node = ex_iterable.num_shards // world_size
+                    plural = "s" if num_shards_per_node > 1 else ""
                     logger.info(
-                        f"Assigning {n_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
+                        f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
                     )
-                ex_iterable = ex_iterable.shard_data_sources(rank, world_size)
+                ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
             else:
                 if self._is_main_process():
                     logger.info(
@@ -2028,7 +2053,7 @@ def _prepare_ex_iterable_for_iteration(
                     logger.info(
                         f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
                         f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
-                        f"The current dataset has {ex_iterable.n_shards} which is not a factor of {world_size}"
+                        f"The current dataset has {ex_iterable.num_shards} which is not a factor of {world_size}"
                     )
                 ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
@@ -2635,6 +2660,63 @@ def take(self, n: int) -> "IterableDataset":
             token_per_repo_id=self._token_per_repo_id,
         )
 
+    def shard(
+        self,
+        num_shards: int,
+        index: int,
+        contiguous: bool = True,
+    ) -> "IterableDataset":
+        """Return the `index`-nth shard from the dataset split into `num_shards` pieces.
+
+        This shards deterministically. `dataset.shard(n, i)` splits the dataset into contiguous chunks,
+        so it can be easily concatenated back together after processing. If `dataset.num_shards % n == l`, then the
+        first `l` returned datasets each have `(dataset.num_shards // n) + 1` shards, and the remaining ones have `(dataset.num_shards // n)` shards.
+        `datasets.concatenate_datasets([dset.shard(n, i) for i in range(n)])` returns a dataset with the same order as the original.
+        In particular, `dataset.shard(dataset.num_shards, i)` returns a dataset with 1 shard.
+
+        Note: `n` should be less than or equal to the number of shards in the dataset, `dataset.num_shards`.
+
+        On the other hand, `dataset.shard(n, i, contiguous=False)` contains all the shards of the dataset whose index mod `n` is equal to `i`.
+
+        Be sure to shard before using any randomizing operator (such as `shuffle`).
+        It is best if the shard operator is used early in the dataset pipeline.
+
+        Args:
+            num_shards (`int`):
+                How many shards to split the dataset into.
+            index (`int`):
+                Which shard to select and return.
+            contiguous (`bool`, defaults to `True`):
+                Whether to select contiguous blocks of indices for shards.
+
+        Example:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("amazon_polarity", split="train", streaming=True)
+        >>> ds
+        IterableDataset({
+            features: ['label', 'title', 'content'],
+            num_shards: 4
+        })
+        >>> ds.shard(num_shards=2, index=0)
+        IterableDataset({
+            features: ['label', 'title', 'content'],
+            num_shards: 2
+        })
+        ```
+        """
+        ex_iterable = self._ex_iterable.shard_data_sources(num_shards=num_shards, index=index, contiguous=contiguous)
+        return IterableDataset(
+            ex_iterable=ex_iterable,
+            info=self._info.copy(),
+            split=self._split,
+            formatting=self._formatting,
+            shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
+            token_per_repo_id=self._token_per_repo_id,
+        )
+
     @property
     def column_names(self) -> Optional[List[str]]:
         """Names of the columns in the dataset.
@@ -3079,7 +3161,7 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
     """
     Split an iterable dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
 
-    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+    If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
     then the shards are evenly assigned across the nodes, which is the most optimized.
     Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
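As a reading aid for the sharding changes above, here is a minimal sketch of the contiguous index math and the new `IterableDataset.shard()` usage. The standalone `contiguous_shard_indices` helper is illustrative only (it mirrors the contiguous branch of `split_shard_indices_by_worker` introduced in this patch but is not part of the library), and the usage at the end assumes a `datasets` build that includes this patch:

```python
from datasets import Dataset, concatenate_datasets


def contiguous_shard_indices(total_shards: int, num_shards: int, index: int) -> list:
    # Same arithmetic as the contiguous branch of split_shard_indices_by_worker:
    # the first (total_shards % num_shards) sub-datasets get one extra shard each.
    div, mod = divmod(total_shards, num_shards)
    start = div * index + min(index, mod)
    end = start + div + (1 if index < mod else 0)
    return list(range(start, end))


# 10 shards split contiguously across 3 sub-datasets: 4 + 3 + 3
assert contiguous_shard_indices(10, 3, 0) == [0, 1, 2, 3]
assert contiguous_shard_indices(10, 3, 1) == [4, 5, 6]
assert contiguous_shard_indices(10, 3, 2) == [7, 8, 9]

# With the patch applied, contiguous shards concatenate back in the original order
ds = Dataset.from_dict({"a": range(20)}).to_iterable_dataset(num_shards=4)
shards = [ds.shard(num_shards=2, index=i) for i in range(2)]  # contiguous=True by default
assert [shard.num_shards for shard in shards] == [2, 2]
assert list(concatenate_datasets(shards)) == list(ds)
```

The default `contiguous=True` is what makes order-preserving concatenation possible, while the round-robin `contiguous=False` split is what the DataLoader worker and `split_dataset_by_node` code paths above request; the `test_iterable_dataset_shard` test added later in this patch exercises both behaviors.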
diff --git a/src/datasets/load.py b/src/datasets/load.py index 0faf2fd5cb5..ebdafafcd5f 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1609,7 +1609,7 @@ def dataset_module_factory( e.__cause__, ( OfflineModeIsEnabled, - requests.exceptions.ConnectTimeout, + requests.exceptions.Timeout, requests.exceptions.ConnectionError, ), ): @@ -1624,7 +1624,7 @@ def dataset_module_factory( ).sha except ( OfflineModeIsEnabled, - requests.exceptions.ConnectTimeout, + requests.exceptions.Timeout, requests.exceptions.ConnectionError, ) as e: raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index 6a23170db5e..94d806edf45 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -14,7 +14,9 @@ from .parquet import parquet from .sql import sql from .text import text +from .videofolder import videofolder from .webdataset import webdataset +from .xml import xml def _hash_python_lines(lines: List[str]) -> str: @@ -40,7 +42,9 @@ def _hash_python_lines(lines: List[str]) -> str: "text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())), "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())), "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())), + "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())), "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())), + "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())), } # get importable module names and hash for caching @@ -69,12 +73,15 @@ def _hash_python_lines(lines: List[str]) -> str: ".arrow": ("arrow", {}), ".txt": ("text", {}), ".tar": ("webdataset", {}), + ".xml": ("xml", {}), } _EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext: ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS}) -_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder"} +_EXTENSION_TO_MODULE.update({ext: ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS}) +_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder", "videofolder"} # Used to filter data files based on extensions given a module name _MODULE_TO_EXTENSIONS: Dict[str, List[str]] = {} diff --git a/src/datasets/packaged_modules/spark/spark.py b/src/datasets/packaged_modules/spark/spark.py index c21cb3dd981..00b97f99c9f 100644 --- a/src/datasets/packaged_modules/spark/spark.py +++ b/src/datasets/packaged_modules/spark/spark.py @@ -100,12 +100,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "SparkExamples generator.shuffle(partition_order) return SparkExamplesIterable(self.df, partition_order=partition_order) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "SparkExamplesIterable": - partition_order = self.split_shard_indices_by_worker(worker_id, num_workers) + def shard_data_sources(self, num_shards: int, index: int, 
contiguous=True) -> "SparkExamplesIterable":
+        partition_order = self.split_shard_indices_by_worker(num_shards=num_shards, index=index, contiguous=contiguous)
         return SparkExamplesIterable(self.df, partition_order=partition_order)
 
     @property
-    def n_shards(self) -> int:
+    def num_shards(self) -> int:
         return len(self.partition_order)
 
diff --git a/src/datasets/packaged_modules/videofolder/__init__.py b/src/datasets/packaged_modules/videofolder/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/videofolder/videofolder.py b/src/datasets/packaged_modules/videofolder/videofolder.py
new file mode 100644
index 00000000000..7ce5bcf5655
--- /dev/null
+++ b/src/datasets/packaged_modules/videofolder/videofolder.py
@@ -0,0 +1,36 @@
+from typing import List
+
+import datasets
+
+from ..folder_based_builder import folder_based_builder
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+class VideoFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
+    """BuilderConfig for VideoFolder."""
+
+    drop_labels: bool = None
+    drop_metadata: bool = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+
+class VideoFolder(folder_based_builder.FolderBasedBuilder):
+    BASE_FEATURE = datasets.Video
+    BASE_COLUMN_NAME = "video"
+    BUILDER_CONFIG_CLASS = VideoFolderConfig
+    EXTENSIONS: List[str]  # definition at the bottom of the script
+
+
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+    ".mkv",
+    ".mp4",
+    ".avi",
+    ".mpeg",
+    ".mov",
+]
+VideoFolder.EXTENSIONS = VIDEO_EXTENSIONS
diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py
index c04f3ba4639..0768437b36a 100644
--- a/src/datasets/packaged_modules/webdataset/webdataset.py
+++ b/src/datasets/packaged_modules/webdataset/webdataset.py
@@ -20,6 +20,7 @@ class WebDataset(datasets.GeneratorBasedBuilder):
     DEFAULT_WRITER_BATCH_SIZE = 100
     IMAGE_EXTENSIONS: List[str]  # definition at the bottom of the script
     AUDIO_EXTENSIONS: List[str]  # definition at the bottom of the script
+    VIDEO_EXTENSIONS: List[str]  # definition at the bottom of the script
     DECODERS: Dict[str, Callable[[Any], Any]]  # definition at the bottom of the script
     NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5
 
@@ -97,6 +98,11 @@ def _split_generators(self, dl_manager):
                 extension = field_name.rsplit(".", 1)[-1]
                 if extension in self.AUDIO_EXTENSIONS:
                     features[field_name] = datasets.Audio()
+            # Set Video types
+            for field_name in first_examples[0]:
+                extension = field_name.rsplit(".", 1)[-1]
+                if extension in self.VIDEO_EXTENSIONS:
+                    features[field_name] = datasets.Video()
             self.info.features = features
 
         return splits
@@ -259,6 +265,17 @@ def base_plus_ext(path):
 WebDataset.AUDIO_EXTENSIONS = AUDIO_EXTENSIONS
 
 
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+    ".mkv",
+    ".mp4",
+    ".avi",
+    ".mpeg",
+    ".mov",
+]
+WebDataset.VIDEO_EXTENSIONS = VIDEO_EXTENSIONS
+
+
 def text_loads(data: bytes):
     return data.decode("utf-8")
 
diff --git a/src/datasets/packaged_modules/xml/__init__.py b/src/datasets/packaged_modules/xml/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/xml/xml.py b/src/datasets/packaged_modules/xml/xml.py
new file mode 100644
index 00000000000..d5009b4dd6a
--- /dev/null
+++ b/src/datasets/packaged_modules/xml/xml.py
@@ -0,0 +1,68 @@
+import itertools
+from dataclasses import dataclass
+from typing import Optional
+
+import pyarrow as pa + +import datasets +from datasets.features.features import require_storage_cast +from datasets.table import table_cast + + +logger = datasets.utils.logging.get_logger(__name__) + + +@dataclass +class XmlConfig(datasets.BuilderConfig): + """BuilderConfig for xml files.""" + + features: Optional[datasets.Features] = None + encoding: str = "utf-8" + encoding_errors: Optional[str] = None + + +class Xml(datasets.ArrowBasedBuilder): + BUILDER_CONFIG_CLASS = XmlConfig + + def _info(self): + return datasets.DatasetInfo(features=self.config.features) + + def _split_generators(self, dl_manager): + """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. + + If str or List[str], then the dataset returns only the 'train' split. + If dict, then keys should be from the `datasets.Split` enum. + """ + if not self.config.data_files: + raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") + dl_manager.download_config.extract_on_the_fly = True + data_files = dl_manager.download_and_extract(self.config.data_files) + splits = [] + for split_name, files in data_files.items(): + if isinstance(files, str): + files = [files] + files = [dl_manager.iter_files(file) for file in files] + splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) + return splits + + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.config.features is not None: + schema = self.config.features.arrow_schema + if all(not require_storage_cast(feature) for feature in self.config.features.values()): + # cheaper cast + pa_table = pa_table.cast(schema) + else: + # more expensive cast; allows str <-> int/float or str to Audio for example + pa_table = table_cast(pa_table, schema) + return pa_table + else: + return pa_table.cast(pa.schema({"xml": pa.string()})) + + def _generate_tables(self, files): + pa_table_names = list(self.config.features) if self.config.features is not None else ["xml"] + for file_idx, file in enumerate(itertools.chain.from_iterable(files)): + # open in text mode, by default translates universal newlines ("\n", "\r\n" and "\r") into "\n" + with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f: + xml = f.read() + pa_table = pa.Table.from_arrays([pa.array([xml])], names=pa_table_names) + yield file_idx, self._cast_table(pa_table) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index e44b1ce12bc..ab2d11bc71a 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -828,8 +828,8 @@ def read_with_retries(*args, **kwargs): except ( aiohttp.client_exceptions.ClientError, asyncio.TimeoutError, - requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError, + requests.exceptions.Timeout, ) as err: disconnect_err = err logger.warning( diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py index 2de35a943e7..154f4d61072 100644 --- a/src/datasets/utils/tf_utils.py +++ b/src/datasets/utils/tf_utils.py @@ -278,7 +278,7 @@ def __init__( self.cols_to_retain = cols_to_retain self.collate_fn = collate_fn self.collate_fn_args = collate_fn_args - self.string_columns = [col for col, dtype in columns_to_np_types.items() if dtype in (np.unicode_, np.str_)] + self.string_columns = [col for col, dtype in columns_to_np_types.items() if dtype is np.str_] # Strings will be converted to arrays of single unicode chars, so that we can have a constant itemsize 
self.columns_to_np_types = { col: dtype if col not in self.string_columns else np.dtype("U1") diff --git a/tests/features/data/test_video_66x50.mov b/tests/features/data/test_video_66x50.mov new file mode 100644 index 00000000000..a55dcaa8f7b Binary files /dev/null and b/tests/features/data/test_video_66x50.mov differ diff --git a/tests/features/test_video.py b/tests/features/test_video.py new file mode 100644 index 00000000000..f4c9a8d830b --- /dev/null +++ b/tests/features/test_video.py @@ -0,0 +1,92 @@ +import pytest + +from datasets import Dataset, Features, Video + +from ..utils import require_decord + + +@require_decord +@pytest.mark.parametrize( + "build_example", + [ + lambda video_path: video_path, + lambda video_path: open(video_path, "rb").read(), + lambda video_path: {"path": video_path}, + lambda video_path: {"path": video_path, "bytes": None}, + lambda video_path: {"path": video_path, "bytes": open(video_path, "rb").read()}, + lambda video_path: {"path": None, "bytes": open(video_path, "rb").read()}, + lambda video_path: {"bytes": open(video_path, "rb").read()}, + ], +) +def test_video_feature_encode_example(shared_datadir, build_example): + from decord import VideoReader + + video_path = str(shared_datadir / "test_video_66x50.mov") + video = Video() + encoded_example = video.encode_example(build_example(video_path)) + assert isinstance(encoded_example, dict) + assert encoded_example.keys() == {"bytes", "path"} + assert encoded_example["bytes"] is not None or encoded_example["path"] is not None + decoded_example = video.decode_example(encoded_example) + assert isinstance(decoded_example, VideoReader) + + +@require_decord +def test_dataset_with_video_feature(shared_datadir): + from decord import VideoReader + from decord.ndarray import NDArray + + video_path = str(shared_datadir / "test_video_66x50.mov") + data = {"video": [video_path]} + features = Features({"video": Video()}) + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"video"} + assert isinstance(item["video"], VideoReader) + assert item["video"][0].shape == (50, 66, 3) + assert isinstance(item["video"][0], NDArray) + batch = dset[:1] + assert len(batch) == 1 + assert batch.keys() == {"video"} + assert isinstance(batch["video"], list) and all(isinstance(item, VideoReader) for item in batch["video"]) + assert batch["video"][0][0].shape == (50, 66, 3) + assert isinstance(batch["video"][0][0], NDArray) + column = dset["video"] + assert len(column) == 1 + assert isinstance(column, list) and all(isinstance(item, VideoReader) for item in column) + assert column[0][0].shape == (50, 66, 3) + assert isinstance(column[0][0], NDArray) + + # from bytes + with open(video_path, "rb") as f: + data = {"video": [f.read()]} + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"video"} + assert isinstance(item["video"], VideoReader) + assert item["video"][0].shape == (50, 66, 3) + assert isinstance(item["video"][0], NDArray) + + +@require_decord +def test_dataset_with_video_map_and_formatted(shared_datadir): + import numpy as np + from decord import VideoReader + + video_path = str(shared_datadir / "test_video_66x50.mov") + data = {"video": [video_path]} + features = Features({"video": Video()}) + dset = Dataset.from_dict(data, features=features) + dset = dset.map(lambda x: x).with_format("numpy") + example = dset[0] + assert isinstance(example["video"], VideoReader) + assert isinstance(example["video"][0], np.ndarray) + + # from bytes + with 
open(video_path, "rb") as f: + data = {"video": [f.read()]} + dset = Dataset.from_dict(data, features=features) + dset = dset.map(lambda x: x).with_format("numpy") + example = dset[0] + assert isinstance(example["video"], VideoReader) + assert isinstance(example["video"][0], np.ndarray) diff --git a/tests/packaged_modules/test_spark.py b/tests/packaged_modules/test_spark.py index c91bdd571ea..89406c6a8dd 100644 --- a/tests/packaged_modules/test_spark.py +++ b/tests/packaged_modules/test_spark.py @@ -72,7 +72,7 @@ def test_spark_examples_iterable(): spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate() df = spark.range(10).repartition(1) it = SparkExamplesIterable(df) - assert it.n_shards == 1 + assert it.num_shards == 1 for i, (row_id, row_dict) in enumerate(it): assert row_id == f"0_{i}" assert row_dict == {"id": i} @@ -89,7 +89,7 @@ def test_spark_examples_iterable_shuffle(): expected_row_ids_and_row_dicts = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [2, 1, 0]) shuffled_it = SparkExamplesIterable(df).shuffle_data_sources(generator_mock) - assert shuffled_it.n_shards == 3 + assert shuffled_it.num_shards == 3 for i, (row_id, row_dict) in enumerate(shuffled_it): expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts[i] assert row_id == expected_row_id @@ -103,8 +103,8 @@ def test_spark_examples_iterable_shard(): df = spark.range(20).repartition(4) # Partitions 0 and 2 - shard_it_1 = SparkExamplesIterable(df).shard_data_sources(worker_id=0, num_workers=2) - assert shard_it_1.n_shards == 2 + shard_it_1 = SparkExamplesIterable(df).shard_data_sources(index=0, num_shards=2, contiguous=False) + assert shard_it_1.num_shards == 2 expected_row_ids_and_row_dicts_1 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [0, 2]) for i, (row_id, row_dict) in enumerate(shard_it_1): expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_1[i] @@ -112,8 +112,8 @@ def test_spark_examples_iterable_shard(): assert row_dict == expected_row_dict # Partitions 1 and 3 - shard_it_2 = SparkExamplesIterable(df).shard_data_sources(worker_id=1, num_workers=2) - assert shard_it_2.n_shards == 2 + shard_it_2 = SparkExamplesIterable(df).shard_data_sources(index=1, num_shards=2, contiguous=False) + assert shard_it_2.num_shards == 2 expected_row_ids_and_row_dicts_2 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [1, 3]) for i, (row_id, row_dict) in enumerate(shard_it_2): expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_2[i] diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index ffa048644e2..1e08862031b 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -2630,10 +2630,12 @@ def test_shard(self, in_memory): tmp_file = os.path.join(tmp_dir, "test.arrow") with dset.select(range(10), indices_cache_file_name=tmp_file) as dset: self.assertEqual(len(dset), 10) - # Shard + # Shard non-contiguous tmp_file_1 = os.path.join(tmp_dir, "test_1.arrow") fingerprint = dset._fingerprint - with dset.shard(num_shards=8, index=1, indices_cache_file_name=tmp_file_1) as dset_sharded: + with dset.shard( + num_shards=8, index=1, contiguous=False, indices_cache_file_name=tmp_file_1 + ) as dset_sharded: self.assertEqual(2, len(dset_sharded)) self.assertEqual(["my_name-train_1", "my_name-train_9"], dset_sharded["filename"]) self.assertDictEqual(dset.features, Features({"filename": Value("string")})) @@ -4268,7 +4270,7 @@ def test_dataset_to_iterable_dataset(dataset: Dataset): assert 
isinstance(iterable_dataset, IterableDataset) assert list(iterable_dataset) == list(dataset) assert iterable_dataset.features == dataset.features - assert iterable_dataset.n_shards == 3 + assert iterable_dataset.num_shards == 3 with pytest.raises(ValueError): dataset.to_iterable_dataset(num_shards=len(dataset) + 1) with pytest.raises(NotImplementedError): diff --git a/tests/test_distributed.py b/tests/test_distributed.py index b8e0f56b180..65d2130f753 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -46,11 +46,11 @@ def gen(shards): gen_kwargs = {"shards": [f"shard_{shard_idx}.txt" for shard_idx in range(num_shards)]} full_ds = IterableDataset.from_generator(gen, gen_kwargs=gen_kwargs) full_size = len(list(full_ds)) - assert full_ds.n_shards == world_size * shards_per_node + assert full_ds.num_shards == world_size * shards_per_node datasets_per_rank = [ split_dataset_by_node(full_ds, rank=rank, world_size=world_size) for rank in range(world_size) ] - assert [ds.n_shards for ds in datasets_per_rank] == [shards_per_node] * world_size + assert [ds.num_shards for ds in datasets_per_rank] == [shards_per_node] * world_size assert sum(len(list(ds)) for ds in datasets_per_rank) == full_size assert len({tuple(x.values()) for ds in datasets_per_rank for x in ds}) == full_size diff --git a/tests/test_hub.py b/tests/test_hub.py index 9485fe83a71..13c496e0f6f 100644 --- a/tests/test_hub.py +++ b/tests/test_hub.py @@ -5,9 +5,10 @@ import pytest from huggingface_hub import CommitOperationAdd, CommitOperationDelete +from packaging import version import datasets -from datasets.config import METADATA_CONFIGS_FIELD +from datasets.config import METADATA_CONFIGS_FIELD, PYARROW_VERSION from datasets.hub import convert_to_parquet, delete_from_hub from datasets.utils.hub import hf_dataset_url @@ -83,7 +84,7 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_ - name: train num_bytes: 55 num_examples: 5 - download_size: 790 + download_size: 726 dataset_size: 55 {METADATA_CONFIGS_FIELD}: - config_name: first @@ -104,7 +105,7 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_ - name: train num_bytes: 60 num_examples: 5 - download_size: 798 + download_size: 732 dataset_size: 60 {METADATA_CONFIGS_FIELD}: - config_name: second @@ -114,6 +115,9 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_ --- """), ] + if PYARROW_VERSION < version.parse("18.0.0"): + expected_readmes[0] = expected_readmes[0].replace("download_size: 726", "download_size: 790") + expected_readmes[1] = expected_readmes[1].replace("download_size: 732", "download_size: 798") for call_args, expected_commit_message, expected_create_pr, expected_readme, expected_parquet_path_in_repo in zip( mock_create_commit.call_args_list, ["Convert dataset to Parquet", "Add 'second' config data files"], diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 232652f1fa3..44d3d24fa23 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1277,7 +1277,7 @@ def gen(shard_names): shard_names = [f"data{shard_idx}.txt" for shard_idx in range(4)] dataset = IterableDataset.from_generator(gen, gen_kwargs={"shard_names": shard_names}) assert isinstance(dataset, IterableDataset) - assert dataset.n_shards == len(shard_names) + assert dataset.num_shards == len(shard_names) @require_numpy1_on_windows @@ -1392,11 +1392,11 @@ def test_iterable_dataset_torch_dataloader_parallel(): @require_torch 
@pytest.mark.filterwarnings("ignore:This DataLoader will create:UserWarning") -@pytest.mark.parametrize("n_shards, num_workers", [(2, 1), (2, 2), (3, 2), (2, 3)]) -def test_sharded_iterable_dataset_torch_dataloader_parallel(n_shards, num_workers): +@pytest.mark.parametrize("num_shards, num_workers", [(2, 1), (2, 2), (3, 2), (2, 3)]) +def test_sharded_iterable_dataset_torch_dataloader_parallel(num_shards, num_workers): from torch.utils.data import DataLoader - ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(n_shards)]}) + ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(num_shards)]}) dataset = IterableDataset(ex_iterable) dataloader = DataLoader(dataset, batch_size=None, num_workers=num_workers) result = list(dataloader) @@ -1681,13 +1681,36 @@ def test_iterable_dataset_take(dataset: IterableDataset, n): assert list(take_dataset) == list(dataset)[:n] +def test_iterable_dataset_shard(): + num_examples = 20 + num_shards = 5 + dataset = Dataset.from_dict({"a": range(num_examples)}).to_iterable_dataset(num_shards=num_shards) + assert sum(dataset.shard(num_shards, i).num_shards for i in range(num_shards)) == dataset.num_shards + assert list(concatenate_datasets([dataset.shard(num_shards, i) for i in range(num_shards)])) == list(dataset) + num_shards = 2 + assert sum(dataset.shard(num_shards, i).num_shards for i in range(num_shards)) == dataset.num_shards + assert list(concatenate_datasets([dataset.shard(num_shards, i) for i in range(num_shards)])) == list(dataset) + assert ( + sum(dataset.shard(num_shards, i, contiguous=False).num_shards for i in range(num_shards)) == dataset.num_shards + ) + assert list( + concatenate_datasets([dataset.shard(num_shards, i, contiguous=False) for i in range(num_shards)]) + ) != list(dataset) + assert sorted( + concatenate_datasets([dataset.shard(num_shards, i, contiguous=False) for i in range(num_shards)]), + key=lambda x: x["a"], + ) == list(dataset) + + @pytest.mark.parametrize("method", ["skip", "take"]) @pytest.mark.parametrize("after_shuffle", [False, True]) @pytest.mark.parametrize("count", [2, 5, 11]) def test_iterable_dataset_skip_or_take_after_shuffle(method, after_shuffle, count): seed = 42 - n, n_shards = 3, 10 - ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]}) + n, num_shards = 3, 10 + ex_iterable = ExamplesIterable( + generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(num_shards)]} + ) dataset = IterableDataset(ex_iterable) shuffled_dataset = dataset if after_shuffle: @@ -1714,9 +1737,11 @@ def test_iterable_dataset_skip_or_take_after_shuffle(method, after_shuffle, coun @pytest.mark.parametrize("after_split_by_node", [False, True]) @pytest.mark.parametrize("count", [2, 5, 11]) def test_iterable_dataset_skip_or_take_after_split_by_node(method, after_split_by_node, count): - n, n_shards = 3, 10 + n, num_shards = 3, 10 rank, world_size = 1, 2 - ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]}) + ex_iterable = ExamplesIterable( + generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(num_shards)]} + ) dataset = IterableDataset(ex_iterable) distributed_dataset = dataset true_distributed_dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) @@ -2114,17 +2139,17 @@ def add_one_numpy(example): assert isinstance(next(dataset.iter(batch_size=3))["id"], list) 
-@pytest.mark.parametrize("n_shards1, n_shards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)]) -def test_interleave_dataset_with_sharding(n_shards1, n_shards2, num_workers): +@pytest.mark.parametrize("num_shards1, num_shards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)]) +def test_interleave_dataset_with_sharding(num_shards1, num_shards2, num_workers): from torch.utils.data import DataLoader - ex_iterable1 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-1.txt" for i in range(n_shards1)]}) + ex_iterable1 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-1.txt" for i in range(num_shards1)]}) dataset1 = IterableDataset(ex_iterable1).with_format("torch") - ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(n_shards2)]}) + ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(num_shards2)]}) dataset2 = IterableDataset(ex_iterable2).with_format("torch") dataset_merged = interleave_datasets([dataset1, dataset2], stopping_strategy="first_exhausted") - assert dataset_merged.n_shards == min(n_shards1, n_shards2) + assert dataset_merged.num_shards == min(num_shards1, num_shards2) dataloader = DataLoader(dataset_merged, batch_size=None, num_workers=num_workers) result = list(dataloader) expected_length = 2 * min( diff --git a/tests/test_offline_util.py b/tests/test_offline_util.py index 7bea143df4b..22a372205ad 100644 --- a/tests/test_offline_util.py +++ b/tests/test_offline_util.py @@ -16,11 +16,11 @@ def test_offline_with_timeout(): with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT): with pytest.raises(RequestWouldHangIndefinitelyError): requests.request("GET", "https://huggingface.co") - with pytest.raises(requests.exceptions.ConnectTimeout): + with pytest.raises(requests.exceptions.Timeout): requests.request("GET", "https://huggingface.co", timeout=1.0) # old versions of `huggingface_hub` don't have timeouts by default and don't allow to set timeouts in HfFileSystem if version.parse(huggingface_hub.__version__) >= version.parse("0.23.0"): - with pytest.raises(requests.exceptions.ConnectTimeout), NamedTemporaryFile() as temp_file: + with pytest.raises(requests.exceptions.Timeout), NamedTemporaryFile() as temp_file: fsspec_get("hf://dummy", temp_file=temp_file) diff --git a/tests/utils.py b/tests/utils.py index e19740a2a12..08497e1eae7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -178,6 +178,18 @@ def require_pil(test_case): return test_case +def require_decord(test_case): + """ + Decorator marking a test that requires decord. + + These tests are skipped when decord isn't installed. + + """ + if not config.DECORD_AVAILABLE: + test_case = unittest.skip("test requires decord")(test_case) + return test_case + + def require_transformers(test_case): """ Decorator marking a test that requires transformers.