diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 690c3a19601..2e3728ef83a 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -72,6 +72,10 @@
title: Semantic segmentation
- local: object_detection
title: Object detection
+ - local: video_load
+ title: Load video data
+ - local: video_dataset
+ title: Create a video dataset
title: "Vision"
- sections:
- local: nlp_load
diff --git a/docs/source/about_mapstyle_vs_iterable.mdx b/docs/source/about_mapstyle_vs_iterable.mdx
index 1e9fa279e11..f794eea5714 100644
--- a/docs/source/about_mapstyle_vs_iterable.mdx
+++ b/docs/source/about_mapstyle_vs_iterable.mdx
@@ -139,12 +139,12 @@ But using a shuffle buffer is not enough to provide a satisfactory shuffling for
```python
# Stream from the internet
my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True)
-my_iterable_dataset.n_shards # 39
+my_iterable_dataset.num_shards # 39
# Stream from local files
data_files = {"train": [f"path/to/data_{i}.csv" for i in range(1024)]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
-my_iterable_dataset.n_shards # 1024
+my_iterable_dataset.num_shards # 1024
# From a generator function
def my_generator(n, sources):
@@ -154,7 +154,7 @@ def my_generator(n, sources):
gen_kwargs = {"n": 10, "sources": [f"path/to/data_{i}" for i in range(1024)]}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs)
-my_iterable_dataset.n_shards # 1024
+my_iterable_dataset.num_shards # 1024
```
## Speed differences
@@ -242,5 +242,5 @@ my_iterable_dataset = my_dataset.to_iterable_dataset()
If you want to shuffle your dataset or [use it with a PyTorch DataLoader](./use_with_pytorch#stream-data), we recommend generating a sharded [`IterableDataset`]:
```python
my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)
-my_iterable_dataset.n_shards # 1024
+my_iterable_dataset.num_shards # 1024
```
diff --git a/docs/source/how_to.md b/docs/source/how_to.md
index 7e6cf8f719e..223a7c2c4c0 100644
--- a/docs/source/how_to.md
+++ b/docs/source/how_to.md
@@ -14,7 +14,7 @@ The guides are organized into six sections:
- General usage: Functions for general dataset loading and processing. The functions shown in this section are applicable across all dataset modalities.
- Audio: How to load, process, and share audio datasets.
-- Vision: How to load, process, and share image datasets.
+- Vision: How to load, process, and share image and video datasets.
- Text: How to load, process, and share text datasets.
- Tabular: How to load, process, and share tabular datasets.
- Dataset repository: How to share and upload a dataset to the Hub.
diff --git a/docs/source/nlp_load.mdx b/docs/source/nlp_load.mdx
index 5cfe5d31e99..dae074ae3fc 100644
--- a/docs/source/nlp_load.mdx
+++ b/docs/source/nlp_load.mdx
@@ -33,4 +33,14 @@ To load remote text files via HTTP, pass the URLs instead:
```py
>>> dataset = load_dataset("text", data_files="https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt")
-```
\ No newline at end of file
+```
+
+To load XML data, you can use the "xml" loader, which is equivalent to "text" with `sample_by="document"`:
+
+```py
+>>> from datasets import load_dataset
+>>> dataset = load_dataset("xml", data_files={"train": ["my_xml_1.xml", "my_xml_2.xml"], "test": "my_xml_file.xml"})
+
+# Load from a directory
+>>> dataset = load_dataset("xml", data_dir="path/to/xml/dataset")
+```
diff --git a/docs/source/package_reference/loading_methods.mdx b/docs/source/package_reference/loading_methods.mdx
index b17cbed8a3b..b54ca4fa1eb 100644
--- a/docs/source/package_reference/loading_methods.mdx
+++ b/docs/source/package_reference/loading_methods.mdx
@@ -49,6 +49,12 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t")
[[autodoc]] datasets.packaged_modules.json.Json
+### XML
+
+[[autodoc]] datasets.packaged_modules.xml.XmlConfig
+
+[[autodoc]] datasets.packaged_modules.xml.Xml
+
### Parquet
[[autodoc]] datasets.packaged_modules.parquet.ParquetConfig
@@ -79,6 +85,12 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t")
[[autodoc]] datasets.packaged_modules.audiofolder.AudioFolder
+### Videos
+
+[[autodoc]] datasets.packaged_modules.videofolder.VideoFolderConfig
+
+[[autodoc]] datasets.packaged_modules.videofolder.VideoFolder
+
### WebDataset
[[autodoc]] datasets.packaged_modules.webdataset.WebDataset
diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx
index c08d555cd29..185bde10d72 100644
--- a/docs/source/package_reference/main_classes.mdx
+++ b/docs/source/package_reference/main_classes.mdx
@@ -171,6 +171,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
- batch
- skip
- take
+ - shard
- load_state_dict
- state_dict
- info
@@ -245,6 +246,10 @@ Dictionary with split names as keys ('train', 'test' for example), and `Iterable
[[autodoc]] datasets.Image
+### Video
+
+[[autodoc]] datasets.Video
+
## Filesystems
[[autodoc]] datasets.filesystems.is_remote_filesystem
diff --git a/docs/source/process.mdx b/docs/source/process.mdx
index 11fd4a1e9b0..38989613ef3 100644
--- a/docs/source/process.mdx
+++ b/docs/source/process.mdx
@@ -736,7 +736,7 @@ Want to save your dataset to a cloud storage provider? Read our [Cloud Storage](
| JSON | [`Dataset.to_json`] |
| Parquet | [`Dataset.to_parquet`] |
| SQL | [`Dataset.to_sql`] |
-| In-memory Python object | [`Dataset.to_pandas`] or [`Dataset.to_dict`] |
+| In-memory Python object | [`Dataset.to_pandas`], [`Dataset.to_polars`] or [`Dataset.to_dict`] |
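+
+As a quick illustrative call for the newly listed [`Dataset.to_polars`] export (assuming `dataset` is the [`Dataset`] object used throughout this guide; it returns a `polars.DataFrame` and requires the `polars` library to be installed):
+
+```py
+>>> polars_df = dataset.to_polars()
+```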
For example, export your dataset to a CSV file like this:
diff --git a/docs/source/share.mdx b/docs/source/share.mdx
index cb3d401fed5..0528d6dc917 100644
--- a/docs/source/share.mdx
+++ b/docs/source/share.mdx
@@ -9,7 +9,7 @@ Dataset repositories offer features such as:
- Commit history and diffs
- Metadata for discoverability
- Dataset cards for documentation, licensing, limitations, etc.
-- [Dataset Viewer](../hub/datasets-viewer)
+- [Dataset Viewer](https://huggingface.co/docs/hub/datasets-viewer)
This guide will show you how to share a dataset folder or repository that can be easily accessed by anyone.
@@ -68,13 +68,13 @@ Check your directory to ensure the only files you're uploading are:
## huggingface-cli upload
-Use the `huggingface-cli upload` command to upload files to the Hub directly. Internally, it uses the same [`upload_file`] and [`upload_folder`] helpers described in the [Upload guide](../huggingface_hub/guides/upload). In the examples below, we will walk through the most common use cases. For a full list of available options, you can run:
+Use the `huggingface-cli upload` command to upload files to the Hub directly. Internally, it uses the same [`upload_file`] and [`upload_folder`] helpers described in the [Upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload). In the examples below, we will walk through the most common use cases. For a full list of available options, you can run:
```bash
>>> huggingface-cli upload --help
```
-For more general information about `huggingface-cli` you can check the [CLI guide](../huggingface_hub/guides/cli).
+For more general information about `huggingface-cli` you can check the [CLI guide](https://huggingface.co/docs/huggingface_hub/guides/cli).
### Upload an entire folder
@@ -214,6 +214,6 @@ Congratulations, your dataset has now been uploaded to the Hugging Face Hub wher
dataset = load_dataset("Wauplin/my-cool-dataset")
```
-If your dataset is supported, it should also have a [Dataset Viewer](../hub/datasets-viewer) for everyone to explore the dataset content.
+If your dataset is supported, it should also have a [Dataset Viewer](https://huggingface.co/docs/hub/datasets-viewer) for everyone to explore the dataset content.
Finally, don't forget to enrich the dataset card to document your dataset and make it discoverable! Check out the [Create a dataset card](dataset_card) guide to learn more.
diff --git a/docs/source/stream.mdx b/docs/source/stream.mdx
index df5109b25aa..0be393ce4a8 100644
--- a/docs/source/stream.mdx
+++ b/docs/source/stream.mdx
@@ -136,6 +136,35 @@ You can split your dataset one of two ways:
+
+### Shard
+
+🤗 Datasets supports sharding to divide a very large dataset into a predefined number of chunks. Specify the `num_shards` parameter in [`~IterableDataset.shard`] to determine the number of shards to split the dataset into. You'll also need to provide the shard you want to return with the `index` parameter.
+
+For example, the [amazon_polarity](https://huggingface.co/datasets/amazon_polarity) dataset has 4 shards (in this case they are 4 Parquet files):
+
+```py
+>>> from datasets import load_dataset
+>>> dataset = load_dataset("amazon_polarity", split="train", streaming=True)
+>>> print(dataset)
+IterableDataset({
+ features: ['label', 'title', 'content'],
+ num_shards: 4
+})
+```
+
+After sharding the dataset into two chunks, the first one will only have 2 shards:
+
+```py
+>>> dataset.shard(num_shards=2, index=0)
+IterableDataset({
+ features: ['label', 'title', 'content'],
+ num_shards: 2
+})
+```
+
+If your dataset has `dataset.num_shards==1`, you should chunk it using [`IterableDataset.skip`] and [`IterableDataset.take`] instead.
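+
+For instance, here is a minimal sketch of splitting a single-shard stream into two chunks with [`IterableDataset.take`] and [`IterableDataset.skip`] (the file path and chunk size below are placeholders):
+
+```py
+>>> from datasets import load_dataset
+>>> single_shard_dataset = load_dataset("csv", data_files="path/to/data.csv", split="train", streaming=True)
+>>> first_chunk = single_shard_dataset.take(500_000)   # first 500,000 examples
+>>> second_chunk = single_shard_dataset.skip(500_000)  # everything after them
+```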
+
## Interleave
[`interleave_datasets`] can combine an [`IterableDataset`] with other datasets. The combined dataset returns alternating examples from each of the original datasets.
diff --git a/docs/source/use_with_pytorch.mdx b/docs/source/use_with_pytorch.mdx
index 7f78d8de05c..375d6facc3e 100644
--- a/docs/source/use_with_pytorch.mdx
+++ b/docs/source/use_with_pytorch.mdx
@@ -216,7 +216,7 @@ If the dataset is split in several shards (i.e. if the dataset consists of multi
```py
>>> my_iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
->>> my_iterable_dataset.n_shards
+>>> my_iterable_dataset.num_shards
39
>>> dataloader = DataLoader(my_iterable_dataset, batch_size=32, num_workers=4)
```
@@ -259,7 +259,7 @@ Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of t
For iterable datasets:
-If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
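+
+As an illustrative sketch (the `rank` and `world_size` values below are placeholders; in practice they come from your distributed setup), assigning the streamed dataset to the current node is done with [`datasets.distributed.split_dataset_by_node`]:
+
+```py
+>>> from datasets.distributed import split_dataset_by_node
+>>> my_iterable_dataset = load_dataset("deepmind/code_contests", streaming=True, split="train")
+>>> my_iterable_dataset = split_dataset_by_node(my_iterable_dataset, rank=0, world_size=8)
+```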
diff --git a/docs/source/video_dataset.mdx b/docs/source/video_dataset.mdx
new file mode 100644
index 00000000000..79cefbd294b
--- /dev/null
+++ b/docs/source/video_dataset.mdx
@@ -0,0 +1,172 @@
+# Create a video dataset
+
+This guide will show you how to create a video dataset with `VideoFolder` and some metadata. This is a no-code solution for quickly creating a video dataset with several thousand videos.
+
+
+
+You can control access to your dataset by requiring users to share their contact information first. Check out the [Gated datasets](https://huggingface.co/docs/hub/datasets-gated) guide for more information about how to enable this feature on the Hub.
+
+
+
+## VideoFolder
+
+The `VideoFolder` is a dataset builder designed to quickly load a video dataset with several thousand videos without requiring you to write any code.
+
+
+
+💡 Take a look at the [Split pattern hierarchy](repository_structure#split-pattern-hierarchy) to learn more about how `VideoFolder` creates dataset splits based on your dataset repository structure.
+
+
+
+`VideoFolder` automatically infers the class labels of your dataset based on the directory name. Store your dataset in a directory structure like:
+
+```
+folder/train/dog/golden_retriever.mp4
+folder/train/dog/german_shepherd.mp4
+folder/train/dog/chihuahua.mp4
+
+folder/train/cat/maine_coon.mp4
+folder/train/cat/bengal.mp4
+folder/train/cat/birman.mp4
+```
+
+Then users can load your dataset by specifying `videofolder` in [`load_dataset`] and the directory in `data_dir`:
+
+```py
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder")
+```
+
+You can also use `videofolder` to load datasets involving multiple splits. To do so, your dataset directory should have the following structure:
+
+```
+folder/train/dog/golden_retriever.mp4
+folder/train/cat/maine_coon.mp4
+folder/test/dog/german_shepherd.mp4
+folder/test/cat/bengal.mp4
+```
+
+
+
+If all video files are contained in a single directory or if they are not on the same level of the directory structure, the `label` column won't be added automatically. If you need it, set `drop_labels=False` explicitly.
+
+
+
+
+If there is additional information you'd like to include about your dataset, like text captions or bounding boxes, add it as a `metadata.csv` file in your folder. This lets you quickly create datasets for different computer vision tasks like text captioning or object detection. You can also use a JSONL file `metadata.jsonl`.
+
+```
+folder/train/metadata.csv
+folder/train/0001.mp4
+folder/train/0002.mp4
+folder/train/0003.mp4
+```
+
+You can also zip your videos:
+
+```
+folder/metadata.csv
+folder/train.zip
+folder/test.zip
+folder/valid.zip
+```
+
+Your `metadata.csv` file must have a `file_name` column which links video files with their metadata:
+
+```csv
+file_name,additional_feature
+0001.mp4,This is a first value of a text feature you added to your videos
+0002.mp4,This is a second value of a text feature you added to your videos
+0003.mp4,This is a third value of a text feature you added to your videos
+```
+
+or using `metadata.jsonl`:
+
+```jsonl
+{"file_name": "0001.mp4", "additional_feature": "This is a first value of a text feature you added to your videos"}
+{"file_name": "0002.mp4", "additional_feature": "This is a second value of a text feature you added to your videos"}
+{"file_name": "0003.mp4", "additional_feature": "This is a third value of a text feature you added to your videos"}
+```
+
+
+
+If metadata files are present, the inferred labels based on the directory name are dropped by default. To include those labels, set `drop_labels=False` in `load_dataset`.
+
+
+
+### Video captioning
+
+Video captioning datasets have text describing a video. An example `metadata.csv` may look like:
+
+```csv
+file_name,text
+0001.mp4,This is a golden retriever playing with a ball
+0002.mp4,A german shepherd
+0003.mp4,One chihuahua
+```
+
+Load the dataset with `VideoFolder`, and it will create a `text` column for the video captions:
+
+```py
+>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder", split="train")
+>>> dataset[0]["text"]
+"This is a golden retriever playing with a ball"
+```
+
+### Upload dataset to the Hub
+
+Once you've created a dataset, you can share it to the Hub using the `huggingface_hub` library. Make sure you have the [huggingface_hub](https://huggingface.co/docs/huggingface_hub/index) library installed and you're logged in to your Hugging Face account (see the [Upload with Python tutorial](upload_dataset#upload-with-python) for more details).
+
+Upload your dataset with `huggingface_hub.HfApi.upload_folder`:
+
+```py
+from huggingface_hub import HfApi
+api = HfApi()
+
+api.upload_folder(
+ folder_path="/path/to/local/dataset",
+ repo_id="username/my-cool-dataset",
+ repo_type="dataset",
+)
+```
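+
+If the repository doesn't exist yet, you may need to create it first with `huggingface_hub.HfApi.create_repo` (a small sketch; the repository id below is just an example):
+
+```py
+api.create_repo(repo_id="username/my-cool-dataset", repo_type="dataset", exist_ok=True)
+```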
+
+## WebDataset
+
+The [WebDataset](https://github.com/webdataset/webdataset) format is based on TAR archives and is suitable for big video datasets.
+You can group your videos in TAR archives (e.g. 1GB of videos per TAR archive) and have thousands of TAR archives:
+
+```
+folder/train/00000.tar
+folder/train/00001.tar
+folder/train/00002.tar
+...
+```
+
+In the archives, each example is made of files sharing the same prefix:
+
+```
+e39871fd9fd74f55.mp4
+e39871fd9fd74f55.json
+f18b91585c4d3f3e.mp4
+f18b91585c4d3f3e.json
+ede6e66b2fb59aab.mp4
+ede6e66b2fb59aab.json
+ed600d57fcee4f94.mp4
+ed600d57fcee4f94.json
+...
+```
+
+You can provide labels, captions, or other features for your videos using JSON or text files, for example.
+
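+Here is a rough sketch of how such archives can be produced with Python's built-in `tarfile` module (the keys, file names and label values below are purely illustrative):
+
+```python
+import io
+import json
+import tarfile
+
+examples = [
+    {"key": "e39871fd9fd74f55", "video_path": "videos/dog.mp4", "label": "dog"},
+    {"key": "f18b91585c4d3f3e", "video_path": "videos/cat.mp4", "label": "cat"},
+]
+
+with tarfile.open("folder/train/00000.tar", "w") as tar:
+    for example in examples:
+        # the video file and its JSON metadata share the same prefix (the "key")
+        tar.add(example["video_path"], arcname=f"{example['key']}.mp4")
+        metadata = json.dumps({"label": example["label"]}).encode("utf-8")
+        info = tarfile.TarInfo(name=f"{example['key']}.json")
+        info.size = len(metadata)
+        tar.addfile(info, io.BytesIO(metadata))
+```
+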
+For more details on the WebDataset format and the python library, please check the [WebDataset documentation](https://webdataset.github.io/webdataset).
+
+Load your WebDataset and it will create one column per file suffix (here "mp4" and "json"):
+
+```python
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", split="train")
+>>> dataset[0]["json"]
+{"bbox": [[302.0, 109.0, 73.0, 52.0]], "categories": [0]}
+```
diff --git a/docs/source/video_load.mdx b/docs/source/video_load.mdx
new file mode 100644
index 00000000000..782869d2eed
--- /dev/null
+++ b/docs/source/video_load.mdx
@@ -0,0 +1,132 @@
+# Load video data
+
+
+
+Video support is experimental and is subject to change.
+
+
+
+Video datasets have [`Video`] type columns, which contain `decord` objects.
+
+
+
+To work with video datasets, you need to have the `decord` package installed. Check out the [installation](./installation) guide to learn how to install dependencies.
+
+
+
+When you load a video dataset and call the video column, the videos are decoded as `decord.VideoReader` objects:
+
+```py
+>>> from datasets import load_dataset, Video
+
+>>> dataset = load_dataset("path/to/video/folder", split="train")
+>>> dataset[0]["video"]
+<decord.VideoReader object at 0x...>
+```
+
+
+
+Index into a video dataset using the row index first and then the `video` column - `dataset[0]["video"]` - to avoid reading all the video objects in the dataset. Otherwise, this can be a slow and time-consuming process if you have a large dataset.
+
+
+
+For a guide on how to load any type of dataset, take a look at the general loading guide.
+
+## Read frames
+
+Access frames directly from a video using the `VideoReader`:
+
+```python
+>>> dataset[0]["video"][0].shape # first frame
+(240, 320, 3)
+```
+
+To get multiple frames at once, use `get_batch`. This is an efficient way to obtain a long list of frames:
+
+```python
+>>> frames = dataset[0]["video"].get_batch([1, 3, 5, 7, 9])
+>>> frames.shape
+(5, 240, 320, 3)
+```
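+
+You can also query basic properties of the decoded video, for example the number of frames and the average frame rate (a small illustrative snippet; the exact values depend on your video):
+
+```python
+>>> video = dataset[0]["video"]
+>>> num_frames = len(video)        # total number of frames
+>>> fps = video.get_avg_fps()      # average frames per second
+>>> duration = num_frames / fps    # approximate duration in seconds
+```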
+
+## Local files
+
+You can load a dataset from the video path. Use the [`~Dataset.cast_column`] function to accept a column of video file paths, and decode it into a `decord` video with the [`Video`] feature:
+```py
+>>> from datasets import Dataset, Video
+
+>>> dataset = Dataset.from_dict({"video": ["path/to/video_1", "path/to/video_2", ..., "path/to/video_n"]}).cast_column("video", Video())
+>>> dataset[0]["video"]
+<decord.VideoReader object at 0x...>
+```
+
+If you only want to load the underlying path to the video dataset without decoding the video object, set `decode=False` in the [`Video`] feature:
+
+```py
+>>> dataset = dataset.cast_column("video", Video(decode=False))
+>>> dataset[0]["video"]
+{'bytes': None,
+ 'path': 'path/to/video/folder/video0.mp4'}
+```
+
+## VideoFolder
+
+You can also load a dataset with the `VideoFolder` dataset builder, which does not require writing a custom dataloader. This makes `VideoFolder` ideal for quickly creating and loading video datasets with several thousand videos for different vision tasks. Your video dataset structure should look like this:
+
+```
+folder/train/dog/golden_retriever.mp4
+folder/train/dog/german_shepherd.mp4
+folder/train/dog/chihuahua.mp4
+
+folder/train/cat/maine_coon.mp4
+folder/train/cat/bengal.mp4
+folder/train/cat/birman.mp4
+```
+
+Load your dataset by specifying `videofolder` and the directory of your dataset in `data_dir`:
+
+```py
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder")
+>>> dataset["train"][0]
+{"video": <decord.VideoReader object at 0x...>, "label": 0}
+
+>>> dataset["train"][-1]
+{"video": <decord.VideoReader object at 0x...>, "label": 1}
+```
+
+Load remote datasets from their URLs with the `data_files` parameter:
+
+```py
+>>> dataset = load_dataset("videofolder", data_files="https://foo.bar/videos.zip", split="train")
+```
+
+Some datasets have a metadata file (`metadata.csv`/`metadata.jsonl`) associated with them, containing other information about the data like bounding boxes, text captions, and labels. The metadata is automatically loaded when you call [`load_dataset`] and specify `videofolder`.
+
+To ignore the information in the metadata file, set `drop_labels=False` in [`load_dataset`], and allow `VideoFolder` to automatically infer the label name from the directory name:
+
+```py
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder", drop_labels=False)
+```
+
+
+
+For more information about creating your own `VideoFolder` dataset, take a look at the [Create a video dataset](./video_dataset) guide.
+
+
+
+## WebDataset
+
+The [WebDataset](https://github.com/webdataset/webdataset) format is based on a folder of TAR archives and is suitable for big video datasets.
+Because of their size, WebDatasets are generally loaded in streaming mode (using `streaming=True`).
+
+You can load a WebDataset like this:
+
+```python
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True)
+```
diff --git a/setup.py b/setup.py
index 7c1391c1b61..65ea27eb731 100644
--- a/setup.py
+++ b/setup.py
@@ -187,6 +187,7 @@
"transformers>=4.42.0", # Pins numpy < 2
"zstandard",
"polars[timezone]>=0.20.0",
+ "decord==0.6.0",
]
@@ -234,7 +235,7 @@
setup(
name="datasets",
- version="3.0.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="3.1.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="HuggingFace community-driven open-source library of datasets",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py
index 6bb25aadabc..4691019efe1 100644
--- a/src/datasets/__init__.py
+++ b/src/datasets/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "3.0.2"
+__version__ = "3.1.0"
from .arrow_dataset import Dataset
from .arrow_reader import ReadInstruction
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index b289fba4106..57f3024e53b 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -77,7 +77,7 @@
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
from .data_files import sanitize_patterns
from .download.streaming_download_manager import xgetsize
-from .features import Audio, ClassLabel, Features, Image, Sequence, Value
+from .features import Audio, ClassLabel, Features, Image, Sequence, Value, Video
from .features.features import (
FeatureType,
_align_features,
@@ -303,7 +303,7 @@ def _get_output_signature(
tf_dtype = tf.float32
np_dtype = np.float32
elif np_arrays[0].dtype.kind == "U": # Unicode strings
- np_dtype = np.unicode_
+ np_dtype = np.str_
tf_dtype = tf.string
else:
raise RuntimeError(
@@ -1416,9 +1416,9 @@ def save_to_disk(
"""
Saves a dataset to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.
- For [`Image`] and [`Audio`] data:
+ For [`Image`], [`Audio`] and [`Video`] data:
- All the Image() and Audio() data are stored in the arrow files.
+ All the Image(), Audio() and Video() data are stored in the arrow files.
If you want to store paths or urls, please use the Value("string") type.
Args:
@@ -4630,32 +4630,31 @@ def shard(
self,
num_shards: int,
index: int,
- contiguous: bool = False,
+ contiguous: bool = True,
keep_in_memory: bool = False,
indices_cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
) -> "Dataset":
"""Return the `index`-nth shard from dataset split into `num_shards` pieces.
- This shards deterministically. `dset.shard(n, i)` will contain all elements of dset whose
- index mod `n = i`.
+ This shards deterministically. `dataset.shard(n, i)` splits the dataset into contiguous chunks,
+ so it can be easily concatenated back together after processing. If `len(dataset) % n == l`, then the
+        first `l` shards each have length `(len(dataset) // n) + 1`, and the remaining shards have length `(len(dataset) // n)`.
+ `datasets.concatenate_datasets([dset.shard(n, i) for i in range(n)])` returns a dataset with the same order as the original.
- `dset.shard(n, i, contiguous=True)` will instead split dset into contiguous chunks,
- so it can be easily concatenated back together after processing. If `n % i == l`, then the
- first `l` shards will have length `(n // i) + 1`, and the remaining shards will have length `(n // i)`.
- `datasets.concatenate([dset.shard(n, i, contiguous=True) for i in range(n)])` will return
- a dataset with the same order as the original.
+        Note: `n` should be less than or equal to the number of elements in the dataset, `len(dataset)`.
+
+ On the other hand, `dataset.shard(n, i, contiguous=False)` contains all elements of the dataset whose index mod `n = i`.
Be sure to shard before using any randomizing operator (such as `shuffle`).
It is best if the shard operator is used early in the dataset pipeline.
-
Args:
num_shards (`int`):
How many shards to split the dataset into.
index (`int`):
Which shard to select and return.
- contiguous: (`bool`, defaults to `False`):
+ contiguous: (`bool`, defaults to `True`):
Whether to select contiguous blocks of indices for shards.
keep_in_memory (`bool`, defaults to `False`):
Keep the dataset in memory instead of writing it to a cache file.
@@ -4663,7 +4662,8 @@ def shard(
Provide the name of a path for the cache file. It is used to store the
indices of each shard instead of the automatically generated cache file name.
writer_batch_size (`int`, defaults to `1000`):
- Number of rows per write operation for the cache file writer.
+ This only concerns the indices mapping.
+ Number of indices per write operation for the cache file writer.
This value is a good trade-off between memory usage during the processing, and processing speed.
Higher value makes the processing do fewer lookups, lower value consume less temporary memory while running `map`.
@@ -4729,7 +4729,7 @@ def to_csv(
**to_csv_kwargs (additional keyword arguments):
- Parameters to pass to pandas's [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_json.html).
+ Parameters to pass to pandas's [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html).
@@ -5065,7 +5065,7 @@ def _estimate_nbytes(self) -> int:
def extra_nbytes_visitor(array, feature):
nonlocal extra_nbytes
- if isinstance(feature, (Audio, Image)):
+ if isinstance(feature, (Audio, Image, Video)):
for x in array.to_pylist():
if x is not None and x["bytes"] is None and x["path"] is not None:
size = xgetsize(x["path"])
@@ -5249,15 +5249,16 @@ def _push_parquet_shards_to_hub(
shards = (self.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards))
if decodable_columns:
+ from .io.parquet import get_writer_batch_size
- def shards_with_embedded_external_files(shards):
+ def shards_with_embedded_external_files(shards: Iterator[Dataset]) -> Iterator[Dataset]:
for shard in shards:
format = shard.format
shard = shard.with_format("arrow")
shard = shard.map(
embed_table_storage,
batched=True,
- batch_size=1000,
+ batch_size=get_writer_batch_size(shard.features),
keep_in_memory=True,
)
shard = shard.with_format(**format)
@@ -5310,7 +5311,7 @@ def push_to_hub(
"""Pushes the dataset to the hub as a Parquet dataset.
The dataset is pushed using HTTP requests and does not need to have neither git or git-lfs installed.
- The resulting Parquet files are self-contained by default. If your dataset contains [`Image`] or [`Audio`]
+ The resulting Parquet files are self-contained by default. If your dataset contains [`Image`], [`Audio`] or [`Video`]
data, the Parquet files will store the bytes of your images or audio files.
You can disable this by setting `embed_external_files` to `False`.
@@ -5399,6 +5400,13 @@ def push_to_hub(
>>> french_dataset = load_dataset("/", "fr")
```
"""
+ if "Video(" in str(self.features):
+ raise NotImplementedError(
+ "push_to_hub is not implemented for video datasets, instead you should upload the video files "
+ "using e.g. the huggingface_hub library and optionally upload a metadata.csv or metadata.jsonl "
+ "file containing other information like video captions, features or labels. More information "
+ "at https://huggingface.co/docs/datasets/main/en/video_load#videofolder"
+ )
if config_name == "data":
raise ValueError("`config_name` cannot be 'data'. Please, choose another name for configuration.")
diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py
index 3b9993736e4..23fd8b94b87 100644
--- a/src/datasets/arrow_writer.py
+++ b/src/datasets/arrow_writer.py
@@ -24,10 +24,11 @@
from fsspec.core import url_to_fs
from . import config
-from .features import Features, Image, Value
+from .features import Audio, Features, Image, Value, Video
from .features.features import (
FeatureType,
_ArrayXDExtensionType,
+ _visit,
cast_to_python_objects,
generate_from_arrow_type,
get_nested_type,
@@ -48,6 +49,45 @@
type_ = type # keep python's type function
+def get_writer_batch_size(features: Optional[Features]) -> Optional[int]:
+ """
+ Get the writer_batch_size that defines the maximum row group size in the parquet files.
+ The default in `datasets` is 1,000 but we lower it to 100 for image/audio datasets and 10 for videos.
+    This allows optimizing random access to the Parquet files, since accessing 1 row requires
+    reading its entire row group.
+
+ This can be improved to get optimized size for querying/iterating
+ but at least it matches the dataset viewer expectations on HF.
+
+ Args:
+ features (`datasets.Features` or `None`):
+ Dataset Features from `datasets`.
+ Returns:
+ writer_batch_size (`Optional[int]`):
+ Writer batch size to pass to a dataset builder.
+ If `None`, then it will use the `datasets` default.
+ """
+ if not features:
+ return None
+
+ batch_size = np.inf
+
+ def set_batch_size(feature: FeatureType) -> None:
+ nonlocal batch_size
+ if isinstance(feature, Image):
+ batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
+ elif isinstance(feature, Audio):
+ batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
+ elif isinstance(feature, Video):
+ batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
+ elif isinstance(feature, Value) and feature.dtype == "binary":
+ batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)
+
+ _visit(features, set_batch_size)
+
+ return None if batch_size is np.inf else batch_size
+
+
class SchemaInferenceError(ValueError):
pass
@@ -340,7 +380,9 @@ def __init__(
self.fingerprint = fingerprint
self.disable_nullable = disable_nullable
- self.writer_batch_size = writer_batch_size or config.DEFAULT_MAX_BATCH_SIZE
+ self.writer_batch_size = (
+ writer_batch_size or get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
+ )
self.update_features = update_features
self.with_metadata = with_metadata
self.unit = unit
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 7328b90cbca..c3eee41c6e0 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -930,7 +930,8 @@ def incomplete_dir(dirname):
# Sync info
self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values())
self.info.download_checksums = dl_manager.get_recorded_sizes_checksums()
- self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
+ if self.info.download_size is not None:
+ self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
# Save info
self._save_info()
diff --git a/src/datasets/config.py b/src/datasets/config.py
index e2de170bcba..43801efcaef 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -67,6 +67,17 @@
except importlib.metadata.PackageNotFoundError:
pass
+
+DUCKDB_VERSION = "N/A"
+DUCKDB_AVAILABLE = importlib.util.find_spec("duckdb") is not None
+
+if DUCKDB_AVAILABLE:
+ try:
+ DUCKDB_VERSION = version.parse(importlib.metadata.version("duckdb"))
+ logger.info(f"Duckdb version {DUCKDB_VERSION} available.")
+ except importlib.metadata.PackageNotFoundError:
+ pass
+
TF_VERSION = "N/A"
TF_AVAILABLE = False
@@ -129,6 +140,7 @@
IS_MP3_SUPPORTED = importlib.util.find_spec("soundfile") is not None and version.parse(
importlib.import_module("soundfile").__libsndfile_version__
) >= version.parse("1.1.0")
+DECORD_AVAILABLE = importlib.util.find_spec("decord") is not None
# Optional compression tools
RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
@@ -192,6 +204,7 @@
PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS = 100
PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS = 100
PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS = 100
+PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS = 10
# Offline mode
_offline = os.environ.get("HF_DATASETS_OFFLINE")
diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index f92a1a8afda..40ca0cd7312 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -1230,9 +1230,9 @@ def save_to_disk(
"""
Saves a dataset dict to a filesystem using `fsspec.spec.AbstractFileSystem`.
- For [`Image`] and [`Audio`] data:
+ For [`Image`], [`Audio`] and [`Video`] data:
- All the Image() and Audio() data are stored in the arrow files.
+ All the Image(), Audio() and Video() data are stored in the arrow files.
If you want to store paths or urls, please use the Value("string") type.
Args:
diff --git a/src/datasets/distributed.py b/src/datasets/distributed.py
index e036fabaf2c..4697948f342 100644
--- a/src/datasets/distributed.py
+++ b/src/datasets/distributed.py
@@ -18,7 +18,7 @@ def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> D
For iterable datasets:
- If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+ If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py
index 9a1d1e3a53c..800f6821443 100644
--- a/src/datasets/download/streaming_download_manager.py
+++ b/src/datasets/download/streaming_download_manager.py
@@ -64,6 +64,8 @@ def __init__(
self._data_dir = data_dir
self._base_path = base_path or os.path.abspath(".")
self.download_config = download_config or DownloadConfig()
+ self.downloaded_size = None
+ self.record_checksums = False
@property
def manual_dir(self):
@@ -208,3 +210,9 @@ def iter_files(self, urlpaths: Union[str, List[str]]) -> Iterable[str]:
```
"""
return FilesIterable.from_urlpaths(urlpaths, download_config=self.download_config)
+
+ def manage_extracted_files(self):
+ pass
+
+ def get_recorded_sizes_checksums(self):
+ pass
diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py
index 35ebfb4ac0c..bf38042eb81 100644
--- a/src/datasets/features/__init__.py
+++ b/src/datasets/features/__init__.py
@@ -12,8 +12,10 @@
"Image",
"Translation",
"TranslationVariableLanguages",
+ "Video",
]
from .audio import Audio
from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value
from .image import Image
from .translation import Translation, TranslationVariableLanguages
+from .video import Video
diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py
index 1d241e0b7b7..34622cd94d9 100644
--- a/src/datasets/features/features.py
+++ b/src/datasets/features/features.py
@@ -43,6 +43,7 @@
from .audio import Audio
from .image import Image, encode_pil_image
from .translation import Translation, TranslationVariableLanguages
+from .video import Video
logger = logging.get_logger(__name__)
@@ -1202,6 +1203,7 @@ class LargeList:
Array5D,
Audio,
Image,
+ Video,
]
@@ -1346,7 +1348,7 @@ def encode_nested_example(schema, obj, level=0):
return list(obj)
# Object with special encoding:
# ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
- elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD)):
+ elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)):
return schema.encode_example(obj) if obj is not None else None
# Other object should be directly convertible to a native Arrow type (like Translation and Translation)
return obj
@@ -1397,7 +1399,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
else:
return decode_nested_example([schema.feature], obj)
# Object with special decoding:
- elif isinstance(schema, (Audio, Image)):
+ elif isinstance(schema, (Audio, Image, Video)):
# we pass the token to read and decode files from private repositories in streaming mode
if obj is not None and schema.decode:
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
@@ -1417,6 +1419,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
Array5D.__name__: Array5D,
Audio.__name__: Audio,
Image.__name__: Image,
+ Video.__name__: Video,
}
diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py
new file mode 100644
index 00000000000..2cde83930ac
--- /dev/null
+++ b/src/datasets/features/video.py
@@ -0,0 +1,296 @@
+import os
+from dataclasses import dataclass, field
+from io import BytesIO
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union
+
+import numpy as np
+import pyarrow as pa
+
+from .. import config
+from ..download.download_config import DownloadConfig
+from ..table import array_cast
+from ..utils.file_utils import is_local_path, xopen
+from ..utils.py_utils import string_to_dict
+
+
+if TYPE_CHECKING:
+ from decord import VideoReader
+
+ from .features import FeatureType
+
+
+@dataclass
+class Video:
+ """
+ **Experimental.** Video [`Feature`] to read video data from a video file.
+
+ Input: The Video feature accepts as input:
+ - A `str`: Absolute path to the video file (i.e. random access is allowed).
+ - A `dict` with the keys:
+
+ - `path`: String with relative path of the video file in a dataset repository.
+ - `bytes`: Bytes of the video file.
+
+ This is useful for archived files with sequential access.
+
+ - A `decord.VideoReader`: decord video reader object.
+
+ Args:
+ decode (`bool`, defaults to `True`):
+ Whether to decode the video data. If `False`,
+ returns the underlying dictionary in the format `{"path": video_path, "bytes": video_bytes}`.
+
+ Examples:
+
+ ```py
+ >>> from datasets import Dataset, Video
+ >>> ds = Dataset.from_dict({"video":["path/to/Screen Recording.mov"]}).cast_column("video", Video())
+ >>> ds.features["video"]
+ Video(decode=True, id=None)
+ >>> ds[0]["video"]
+    <decord.VideoReader object at 0x...>
+ >>> ds = ds.cast_column('video', Video(decode=False))
+ {'bytes': None,
+ 'path': 'path/to/Screen Recording.mov'}
+ ```
+ """
+
+ decode: bool = True
+ id: Optional[str] = None
+ # Automatically constructed
+ dtype: ClassVar[str] = "decord.VideoReader"
+ pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()})
+ _type: str = field(default="Video", init=False, repr=False)
+
+ def __post_init__(self):
+ if config.DECORD_AVAILABLE:
+ patch_decord()
+
+ def __call__(self):
+ return self.pa_type
+
+ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader"]) -> dict:
+ """Encode example into a format for Arrow.
+
+ Args:
+ value (`str`, `np.ndarray`, `VideoReader` or `dict`):
+ Data passed as input to Video feature.
+
+ Returns:
+ `dict` with "path" and "bytes" fields
+ """
+ if config.DECORD_AVAILABLE:
+ from decord import VideoReader
+
+ else:
+ VideoReader = None
+
+ if isinstance(value, list):
+ value = np.array(value)
+
+ if isinstance(value, str):
+ return {"path": value, "bytes": None}
+ elif isinstance(value, bytes):
+ return {"path": None, "bytes": value}
+ elif isinstance(value, np.ndarray):
+ # convert the video array to bytes
+ return encode_np_array(value)
+ elif VideoReader and isinstance(value, VideoReader):
+ # convert the decord video reader to bytes
+ return encode_decord_video(value)
+ elif value.get("path") is not None and os.path.isfile(value["path"]):
+ # we set "bytes": None to not duplicate the data if they're already available locally
+ return {"bytes": None, "path": value.get("path")}
+ elif value.get("bytes") is not None or value.get("path") is not None:
+ # store the video bytes, and path is used to infer the video format using the file extension
+ return {"bytes": value.get("bytes"), "path": value.get("path")}
+ else:
+ raise ValueError(
+ f"A video sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
+ )
+
+ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
+ """Decode example video file into video data.
+
+ Args:
+ value (`str` or `dict`):
+ A string with the absolute video file path, a dictionary with
+ keys:
+
+ - `path`: String with absolute or relative video file path.
+ - `bytes`: The bytes of the video file.
+ token_per_repo_id (`dict`, *optional*):
+ To access and decode
+ video files from private repositories on the Hub, you can pass
+ a dictionary repo_id (`str`) -> token (`bool` or `str`).
+
+ Returns:
+ `decord.VideoReader`
+ """
+ if not self.decode:
+ raise RuntimeError("Decoding is disabled for this feature. Please use Video(decode=True) instead.")
+
+ if config.DECORD_AVAILABLE:
+ from decord import VideoReader
+
+ else:
+ raise ImportError("To support decoding videos, please install 'decord'.")
+
+ if token_per_repo_id is None:
+ token_per_repo_id = {}
+
+ path, bytes_ = value["path"], value["bytes"]
+ if bytes_ is None:
+ if path is None:
+ raise ValueError(f"A video should have one of 'path' or 'bytes' but both are None in {value}.")
+ else:
+ if is_local_path(path):
+ video = VideoReader(path)
+ else:
+ source_url = path.split("::")[-1]
+ pattern = (
+ config.HUB_DATASETS_URL
+ if source_url.startswith(config.HF_ENDPOINT)
+ else config.HUB_DATASETS_HFFS_URL
+ )
+ try:
+ repo_id = string_to_dict(source_url, pattern)["repo_id"]
+ token = token_per_repo_id.get(repo_id)
+ except ValueError:
+ token = None
+ download_config = DownloadConfig(token=token)
+ with xopen(path, "rb", download_config=download_config) as f:
+ bytes_ = BytesIO(f.read())
+ video = VideoReader(bytes_)
+ else:
+ video = VideoReader(BytesIO(bytes_))
+ return video
+
+ def flatten(self) -> Union["FeatureType", Dict[str, "FeatureType"]]:
+ """If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary."""
+ from .features import Value
+
+ return (
+ self
+ if self.decode
+ else {
+ "bytes": Value("binary"),
+ "path": Value("string"),
+ }
+ )
+
+ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArray]) -> pa.StructArray:
+ """Cast an Arrow array to the Video arrow storage type.
+ The Arrow types that can be converted to the Video pyarrow storage type are:
+
+ - `pa.string()` - it must contain the "path" data
+ - `pa.binary()` - it must contain the video bytes
+ - `pa.struct({"bytes": pa.binary()})`
+ - `pa.struct({"path": pa.string()})`
+ - `pa.struct({"bytes": pa.binary(), "path": pa.string()})` - order doesn't matter
+ - `pa.list(*)` - it must contain the video array data
+
+ Args:
+ storage (`Union[pa.StringArray, pa.StructArray, pa.ListArray]`):
+ PyArrow array to cast.
+
+ Returns:
+ `pa.StructArray`: Array in the Video arrow storage type, that is
+ `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
+ """
+ if pa.types.is_string(storage.type):
+ bytes_array = pa.array([None] * len(storage), type=pa.binary())
+ storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
+ elif pa.types.is_binary(storage.type):
+ path_array = pa.array([None] * len(storage), type=pa.string())
+ storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())
+ elif pa.types.is_struct(storage.type):
+ if storage.type.get_field_index("bytes") >= 0:
+ bytes_array = storage.field("bytes")
+ else:
+ bytes_array = pa.array([None] * len(storage), type=pa.binary())
+ if storage.type.get_field_index("path") >= 0:
+ path_array = storage.field("path")
+ else:
+ path_array = pa.array([None] * len(storage), type=pa.string())
+ storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
+ elif pa.types.is_list(storage.type):
+ bytes_array = pa.array(
+ [encode_np_array(np.array(arr))["bytes"] if arr is not None else None for arr in storage.to_pylist()],
+ type=pa.binary(),
+ )
+ path_array = pa.array([None] * len(storage), type=pa.string())
+ storage = pa.StructArray.from_arrays(
+ [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()
+ )
+ return array_cast(storage, self.pa_type)
+
+
+def video_to_bytes(video: "VideoReader") -> bytes:
+ """Convert a decord Video object to bytes using native compression if possible"""
+ raise NotImplementedError()
+
+
+def encode_decord_video(video: "VideoReader") -> dict:
+ if hasattr(video, "_hf_encoded"):
+ return video._hf_encoded
+ else:
+ raise NotImplementedError(
+ "Encoding a decord video is not implemented. "
+ "Please call `datasets.features.video.patch_decord()` before loading videos to enable this."
+ )
+
+
+def encode_np_array(array: np.ndarray) -> dict:
+ raise NotImplementedError()
+
+
+# Patching decord a little bit to:
+# 1. store the encoded video data {"path": ..., "bytes": ...} in `video._hf_encoded`
+# 2. set the decord bridge to numpy/torch/tf/jax using `video._hf_bridge_out` (per video instance) instead of decord.bridge.bridge_out (global)
+# This doesn't affect the normal usage of decord.
+
+
+def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None:
+ from decord.bridge import bridge_out
+
+ if hasattr(uri, "read"):
+ self._hf_encoded = {"bytes": uri.read(), "path": None}
+ uri.seek(0)
+ elif isinstance(uri, str):
+ self._hf_encoded = {"bytes": None, "path": uri}
+ self._hf_bridge_out = bridge_out
+ self._original_init(uri, *args, **kwargs)
+
+
+def _patched_next(self: "VideoReader", *args, **kwargs):
+ return self._hf_bridge_out(self._original_next(*args, **kwargs))
+
+
+def _patched_get_batch(self: "VideoReader", *args, **kwargs):
+ return self._hf_bridge_out(self._original_get_batch(*args, **kwargs))
+
+
+def patch_decord():
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ # Same for duckdb which crashes on import
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ if config.DUCKDB_AVAILABLE:
+ import duckdb # noqa
+ import decord.video_reader
+ from decord import VideoReader
+
+ if not hasattr(VideoReader, "_hf_patched"):
+ decord.video_reader.bridge_out = lambda x: x
+ VideoReader._original_init = VideoReader.__init__
+ VideoReader.__init__ = _patched_init
+ VideoReader._original_next = VideoReader.next
+ VideoReader.next = _patched_next
+ VideoReader._original_get_batch = VideoReader.get_batch
+ VideoReader.get_batch = _patched_get_batch
+ VideoReader._hf_patched = True
diff --git a/src/datasets/formatting/jax_formatter.py b/src/datasets/formatting/jax_formatter.py
index 8035341c5cd..e247b7b5822 100644
--- a/src/datasets/formatting/jax_formatter.py
+++ b/src/datasets/formatting/jax_formatter.py
@@ -100,11 +100,23 @@ def _tensorize(self, value):
default_dtype = {"dtype": jnp.int32}
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": jnp.float32}
- elif config.PIL_AVAILABLE and "PIL" in sys.modules:
+
+ if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image
if isinstance(value, PIL.Image.Image):
value = np.asarray(value)
+ if config.DECORD_AVAILABLE and "decord" in sys.modules:
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ from decord import VideoReader
+
+ if isinstance(value, VideoReader):
+ value._hf_bridge_out = lambda x: jnp.array(np.asarray(x))
+ return value
# using global variable since `jaxlib.xla_extension.Device` is not serializable neither
# with `pickle` nor with `dill`, so we need to use a global variable instead
diff --git a/src/datasets/formatting/np_formatter.py b/src/datasets/formatting/np_formatter.py
index 95bcff2b517..032758bce21 100644
--- a/src/datasets/formatting/np_formatter.py
+++ b/src/datasets/formatting/np_formatter.py
@@ -57,11 +57,23 @@ def _tensorize(self, value):
default_dtype = {"dtype": np.int64}
elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": np.float32}
- elif config.PIL_AVAILABLE and "PIL" in sys.modules:
+
+ if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image
if isinstance(value, PIL.Image.Image):
return np.asarray(value, **self.np_array_kwargs)
+ if config.DECORD_AVAILABLE and "decord" in sys.modules:
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ from decord import VideoReader
+
+ if isinstance(value, VideoReader):
+ value._hf_bridge_out = np.asarray
+ return value
return np.asarray(value, **{**default_dtype, **self.np_array_kwargs})
diff --git a/src/datasets/formatting/tf_formatter.py b/src/datasets/formatting/tf_formatter.py
index adb15cda381..9f0c06ec82a 100644
--- a/src/datasets/formatting/tf_formatter.py
+++ b/src/datasets/formatting/tf_formatter.py
@@ -64,11 +64,24 @@ def _tensorize(self, value):
default_dtype = {"dtype": tf.int64}
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": tf.float32}
- elif config.PIL_AVAILABLE and "PIL" in sys.modules:
+
+ if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image
if isinstance(value, PIL.Image.Image):
value = np.asarray(value)
+ if config.DECORD_AVAILABLE and "decord" in sys.modules:
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ from decord import VideoReader
+ from decord.bridge import to_tensorflow
+
+ if isinstance(value, VideoReader):
+ value._hf_bridge_out = to_tensorflow
+ return value
return tf.convert_to_tensor(value, **{**default_dtype, **self.tf_tensor_kwargs})
diff --git a/src/datasets/formatting/torch_formatter.py b/src/datasets/formatting/torch_formatter.py
index 8efe759a144..051badb0ac4 100644
--- a/src/datasets/formatting/torch_formatter.py
+++ b/src/datasets/formatting/torch_formatter.py
@@ -66,7 +66,8 @@ def _tensorize(self, value):
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": torch.float32}
- elif config.PIL_AVAILABLE and "PIL" in sys.modules:
+
+ if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image
if isinstance(value, PIL.Image.Image):
@@ -75,6 +76,14 @@ def _tensorize(self, value):
value = value[:, :, np.newaxis]
value = value.transpose((2, 0, 1))
+ if config.DECORD_AVAILABLE and "decord" in sys.modules:
+ from decord import VideoReader
+ from decord.bridge import to_torch
+
+ if isinstance(value, VideoReader):
+ value._hf_bridge_out = to_torch
+ return value
+
return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
def _recursive_tensorize(self, data_struct):
diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py
index 51434106fb1..d34f5110204 100644
--- a/src/datasets/io/parquet.py
+++ b/src/datasets/io/parquet.py
@@ -2,11 +2,10 @@
from typing import BinaryIO, Optional, Union
import fsspec
-import numpy as np
import pyarrow.parquet as pq
-from .. import Audio, Dataset, Features, Image, NamedSplit, Value, config
-from ..features.features import FeatureType, _visit
+from .. import Dataset, Features, NamedSplit, config
+from ..arrow_writer import get_writer_batch_size
from ..formatting import query_table
from ..packaged_modules import _PACKAGED_DATASETS_MODULES
from ..packaged_modules.parquet.parquet import Parquet
@@ -15,41 +14,6 @@
from .abc import AbstractDatasetReader
-def get_writer_batch_size(features: Features) -> Optional[int]:
- """
- Get the writer_batch_size that defines the maximum row group size in the parquet files.
- The default in `datasets` is 1,000 but we lower it to 100 for image datasets.
- This allows to optimize random access to parquet file, since accessing 1 row requires
- to read its entire row group.
-
- This can be improved to get optimized size for querying/iterating
- but at least it matches the dataset viewer expectations on HF.
-
- Args:
- ds_config_info (`datasets.info.DatasetInfo`):
- Dataset info from `datasets`.
- Returns:
- writer_batch_size (`Optional[int]`):
- Writer batch size to pass to a dataset builder.
- If `None`, then it will use the `datasets` default.
- """
-
- batch_size = np.inf
-
- def set_batch_size(feature: FeatureType) -> None:
- nonlocal batch_size
- if isinstance(feature, Image):
- batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
- elif isinstance(feature, Audio):
- batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
- elif isinstance(feature, Value) and feature.dtype == "binary":
- batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)
-
- _visit(features, set_batch_size)
-
- return None if batch_size is np.inf else batch_size
-
-
class ParquetDatasetReader(AbstractDatasetReader):
def __init__(
self,
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index d1b54131b61..e38244383e2 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -141,16 +141,23 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamples
"""
raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet")
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "_BaseExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_BaseExamplesIterable":
"""Either keep only the requested shard, or propagate the request to the underlying iterable."""
raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet")
- def split_shard_indices_by_worker(self, worker_id: int, num_workers: int) -> List[int]:
- return list(range(worker_id, self.n_shards, num_workers))
+ def split_shard_indices_by_worker(self, num_shards: int, index: int, contiguous=True) -> List[int]:
+ if contiguous:
+ div = self.num_shards // num_shards
+ mod = self.num_shards % num_shards
+ start = div * index + min(index, mod)
+ end = start + div + (1 if index < mod else 0)
+ return list(range(start, end))
+ else:
+ return list(range(index, self.num_shards, num_shards))
@property
- def n_shards(self) -> int:
- raise NotImplementedError(f"{type(self)} doesn't implement n_shards yet")
+ def num_shards(self) -> int:
+ raise NotImplementedError(f"{type(self)} doesn't implement num_shards yet")
def _init_state_dict(self) -> dict:
raise NotImplementedError(f"{type(self)} doesn't implement _init_state_dict yet")
@@ -187,7 +194,7 @@ def _init_state_dict(self) -> dict:
def __iter__(self):
shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0
- for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards), shard_idx_start, None):
+ for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None):
shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0
for key_example in islice(self.generate_examples_fn(**gen_kwags), shard_example_idx_start, None):
if self._state_dict:
@@ -200,15 +207,15 @@ def __iter__(self):
def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable":
return ShuffledDataSourcesExamplesIterable(self.generate_examples_fn, self.kwargs, generator)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "ExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable":
"""Keep only the requested shard."""
- gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards)
- shard_indices = self.split_shard_indices_by_worker(worker_id, num_workers)
+ gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards)
+ shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous)
requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices])
return ExamplesIterable(self.generate_examples_fn, requested_gen_kwargs)
@property
- def n_shards(self) -> int:
+ def num_shards(self) -> int:
return _number_of_shards_in_gen_kwargs(self.kwargs)
@@ -229,7 +236,7 @@ def __iter__(self):
kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs)
shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0
for gen_kwags in islice(
- _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards), shard_idx_start, None
+ _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None
):
shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0
for key_example in islice(self.generate_examples_fn(**gen_kwags), shard_example_idx_start, None):
@@ -240,12 +247,12 @@ def __iter__(self):
self._state_dict["shard_idx"] += 1
self._state_dict["shard_example_idx"] = 0
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "ExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable":
"""Keep only the requested shard."""
rng = deepcopy(self.generator)
kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs)
return ExamplesIterable(self.generate_examples_fn, kwargs_with_shuffled_shards).shard_data_sources(
- worker_id, num_workers
+ num_shards, index, contiguous=contiguous
)
@@ -266,7 +273,7 @@ def _init_state_dict(self) -> dict:
def __iter__(self):
formatter = PythonFormatter()
shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0
- for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards), shard_idx_start, None):
+ for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None):
shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0
shard_example_idx = 0
for key, pa_table in self.generate_tables_fn(**gen_kwags):
@@ -287,7 +294,7 @@ def __iter__(self):
def _iter_arrow(self):
shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0
- for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards), shard_idx_start, None):
+ for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None):
shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0
shard_example_idx = 0
for key, pa_table in self.generate_tables_fn(**gen_kwags):
@@ -304,15 +311,15 @@ def _iter_arrow(self):
def shuffle_data_sources(self, generator: np.random.Generator) -> "ArrowExamplesIterable":
return ShuffledDataSourcesArrowExamplesIterable(self.generate_tables_fn, self.kwargs, generator)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "ArrowExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable":
"""Keep only the requested shard."""
- gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards)
- shard_indices = self.split_shard_indices_by_worker(worker_id, num_workers)
+ gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards)
+ shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous)
requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices])
return ArrowExamplesIterable(self.generate_tables_fn, requested_gen_kwargs)
@property
- def n_shards(self) -> int:
+ def num_shards(self) -> int:
return _number_of_shards_in_gen_kwargs(self.kwargs)
@@ -337,7 +344,7 @@ def __iter__(self):
formatter = PythonFormatter()
shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0
for gen_kwags in islice(
- _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards), shard_idx_start, None
+ _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None
):
shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0
shard_example_idx = 0
@@ -362,7 +369,7 @@ def _iter_arrow(self):
kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs)
shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0
for gen_kwags in islice(
- _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards), shard_idx_start, None
+ _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.num_shards), shard_idx_start, None
):
shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0
shard_example_idx = 0
@@ -377,12 +384,12 @@ def _iter_arrow(self):
self._state_dict["shard_idx"] += 1
self._state_dict["shard_example_idx"] = 0
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "ArrowExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable":
"""Keep only the requested shard."""
rng = deepcopy(self.generator)
kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs)
return ArrowExamplesIterable(self.generate_tables_fn, kwargs_with_shuffled_shards).shard_data_sources(
- worker_id, num_workers
+ num_shards, index, contiguous=contiguous
)
@@ -505,14 +512,16 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RebatchedArro
self.ex_iterable.shuffle_data_sources(generator), self.batch_size, self.drop_last_batch
)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "RebatchedArrowExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RebatchedArrowExamplesIterable":
return RebatchedArrowExamplesIterable(
- self.ex_iterable.shard_data_sources(worker_id, num_workers), self.batch_size, self.drop_last_batch
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
+ self.batch_size,
+ self.drop_last_batch,
)
@property
- def n_shards(self) -> int:
- return self.ex_iterable.n_shards
+ def num_shards(self) -> int:
+ return self.ex_iterable.num_shards
class SelectColumnsIterable(_BaseExamplesIterable):
@@ -546,12 +555,14 @@ def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]:
def shuffle_data_sources(self, generator: np.random.Generator) -> "SelectColumnsIterable":
return SelectColumnsIterable(self.ex_iterable.shuffle_data_sources(generator), self.column_names)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "SelectColumnsIterable":
- return SelectColumnsIterable(self.ex_iterable.shard_data_sources(worker_id, num_workers), self.column_names)
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SelectColumnsIterable":
+ return SelectColumnsIterable(
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.column_names
+ )
@property
- def n_shards(self) -> int:
- return self.ex_iterable.n_shards
+ def num_shards(self) -> int:
+ return self.ex_iterable.num_shards
class StepExamplesIterable(_BaseExamplesIterable):
@@ -584,14 +595,16 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "StepExamplesI
self.ex_iterable.shuffle_data_sources(generator), step=self.step, offset=self.offset
)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "StepExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "StepExamplesIterable":
return StepExamplesIterable(
- self.ex_iterable.shard_data_sources(worker_id, num_workers), step=self.step, offset=self.offset
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
+ step=self.step,
+ offset=self.offset,
)
@property
- def n_shards(self) -> int:
- return self.ex_iterable.n_shards
+ def num_shards(self) -> int:
+ return self.ex_iterable.num_shards
class CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable):
@@ -679,13 +692,15 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "CyclingMultiS
return CyclingMultiSourcesExamplesIterable(ex_iterables, self.stopping_strategy)
@property
- def n_shards(self) -> int:
- return min(ex_iterable.n_shards for ex_iterable in self.ex_iterables)
+ def num_shards(self) -> int:
+ return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "CyclingMultiSourcesExamplesIterable":
+ def shard_data_sources(
+ self, num_shards: int, index: int, contiguous=True
+ ) -> "CyclingMultiSourcesExamplesIterable":
"""Either keep only the requested shard, or propagate the request to the underlying iterable."""
return CyclingMultiSourcesExamplesIterable(
- [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables],
+ [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables],
stopping_strategy=self.stopping_strategy,
)
@@ -748,15 +763,15 @@ def shuffle_data_sources(
return VerticallyConcatenatedMultiSourcesExamplesIterable(ex_iterables)
@property
- def n_shards(self) -> int:
- return min(ex_iterable.n_shards for ex_iterable in self.ex_iterables)
+ def num_shards(self) -> int:
+ return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables)
def shard_data_sources(
- self, worker_id: int, num_workers: int
+ self, num_shards: int, index: int, contiguous=True
) -> "VerticallyConcatenatedMultiSourcesExamplesIterable":
"""Either keep only the requested shard, or propagate the request to the underlying iterable."""
return VerticallyConcatenatedMultiSourcesExamplesIterable(
- [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables]
+ [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables]
)
@@ -829,15 +844,15 @@ def shuffle_data_sources(
return self
@property
- def n_shards(self) -> int:
+ def num_shards(self) -> int:
return 1
def shard_data_sources(
- self, worker_id: int, num_workers: int
+ self, num_shards: int, index: int, contiguous=True
) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable":
"""Either keep only the requested shard, or propagate the request to the underlying iterable."""
return HorizontallyConcatenatedMultiSourcesExamplesIterable(
- [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables]
+ [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables]
)
@@ -907,10 +922,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RandomlyCycli
stopping_strategy=self.stopping_strategy,
)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "RandomlyCyclingMultiSourcesExamplesIterable":
+ def shard_data_sources(
+ self, num_shards: int, index: int, contiguous=True
+ ) -> "RandomlyCyclingMultiSourcesExamplesIterable":
"""Either keep only the requested shard, or propagate the request to the underlying iterable."""
return RandomlyCyclingMultiSourcesExamplesIterable(
- [iterable.shard_data_sources(worker_id, num_workers) for iterable in self.ex_iterables],
+ [iterable.shard_data_sources(num_shards, index, contiguous=contiguous) for iterable in self.ex_iterables],
self.generator,
self.probabilities,
self.stopping_strategy,
@@ -1161,10 +1178,10 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample
formatting=self.formatting,
)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "MappedExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MappedExamplesIterable":
"""Keep only the requested shard."""
return MappedExamplesIterable(
- self.ex_iterable.shard_data_sources(worker_id, num_workers),
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
function=self.function,
with_indices=self.with_indices,
input_columns=self.input_columns,
@@ -1177,8 +1194,8 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "MappedExample
)
@property
- def n_shards(self) -> int:
- return self.ex_iterable.n_shards
+ def num_shards(self) -> int:
+ return self.ex_iterable.num_shards
class FilteredExamplesIterable(_BaseExamplesIterable):
@@ -1381,10 +1398,10 @@ def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable
batch_size=self.batch_size,
)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "FilteredExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "FilteredExamplesIterable":
"""Keep only the requested shard."""
return FilteredExamplesIterable(
- self.ex_iterable.shard_data_sources(worker_id, num_workers),
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
function=self.function,
with_indices=self.with_indices,
input_columns=self.input_columns,
@@ -1393,8 +1410,8 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "FilteredExamp
)
@property
- def n_shards(self) -> int:
- return self.ex_iterable.n_shards
+ def num_shards(self) -> int:
+ return self.ex_iterable.num_shards
class BufferShuffledExamplesIterable(_BaseExamplesIterable):
@@ -1451,17 +1468,17 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffle
self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "BufferShuffledExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "BufferShuffledExamplesIterable":
"""Keep only the requested shard."""
return BufferShuffledExamplesIterable(
- self.ex_iterable.shard_data_sources(worker_id, num_workers),
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
buffer_size=self.buffer_size,
generator=self.generator,
)
@property
- def n_shards(self) -> int:
- return self.ex_iterable.n_shards
+ def num_shards(self) -> int:
+ return self.ex_iterable.num_shards
class SkipExamplesIterable(_BaseExamplesIterable):
@@ -1514,12 +1531,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "SkipExamplesI
split_when_sharding=self.split_when_sharding,
)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "SkipExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SkipExamplesIterable":
"""Keep only the requested shard."""
if self.split_when_sharding:
return SkipExamplesIterable(
- self.ex_iterable.shard_data_sources(worker_id, num_workers),
- n=self.split_number(self.n, num_workers)[worker_id],
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
+ n=self.split_number(self.n, num_shards)[index],
block_sources_order_when_shuffling=self.block_sources_order_when_shuffling,
split_when_sharding=self.split_when_sharding,
)
@@ -1527,8 +1544,8 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "SkipExamplesI
return self
@property
- def n_shards(self) -> int:
- return self.ex_iterable.n_shards
+ def num_shards(self) -> int:
+ return self.ex_iterable.num_shards
class TakeExamplesIterable(_BaseExamplesIterable):
@@ -1582,26 +1599,26 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "TakeExamplesI
split_when_sharding=self.split_when_sharding,
)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "TakeExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "TakeExamplesIterable":
"""Keep only the requested shard."""
if self.split_when_sharding:
return TakeExamplesIterable(
- self.ex_iterable.shard_data_sources(worker_id, num_workers),
- n=self.split_number(self.n, num_workers)[worker_id],
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
+ n=self.split_number(self.n, num_shards)[index],
block_sources_order_when_shuffling=self.block_sources_order_when_shuffling,
split_when_sharding=self.split_when_sharding,
)
else:
return TakeExamplesIterable(
- self.ex_iterable.shard_data_sources(worker_id, num_workers),
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
n=self.n,
block_sources_order_when_shuffling=self.block_sources_order_when_shuffling,
split_when_sharding=self.split_when_sharding,
)
@property
- def n_shards(self) -> int:
- return self.ex_iterable.n_shards
+ def num_shards(self) -> int:
+ return self.ex_iterable.num_shards
def _apply_feature_types_on_example(
@@ -1690,17 +1707,17 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "TypedExamples
token_per_repo_id=self.token_per_repo_id,
)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "TypedExamplesIterable":
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "TypedExamplesIterable":
"""Keep only the requested shard."""
return TypedExamplesIterable(
- self.ex_iterable.shard_data_sources(worker_id, num_workers),
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
features=self.features,
token_per_repo_id=self.token_per_repo_id,
)
@property
- def n_shards(self) -> int:
- return self.ex_iterable.n_shards
+ def num_shards(self) -> int:
+ return self.ex_iterable.num_shards
@dataclass
@@ -1885,7 +1902,7 @@ def load_state_dict(self, state_dict: dict) -> None:
self._starting_state_dict = state_dict
def __repr__(self):
- return f"IterableDataset({{\n features: {list(self._info.features.keys()) if self._info.features is not None else 'Unknown'},\n n_shards: {self.n_shards}\n}})"
+ return f"IterableDataset({{\n features: {list(self._info.features.keys()) if self._info.features is not None else 'Unknown'},\n num_shards: {self.num_shards}\n}})"
def __getstate__(self):
return self.__dict__
@@ -1916,10 +1933,14 @@ def _effective_generator(self):
raise ValueError("This dataset is not shuffled")
@property
- def n_shards(self) -> int:
- if self._distributed and self._ex_iterable.n_shards % self._distributed.world_size == 0:
- return self._ex_iterable.n_shards // self._distributed.world_size
- return self._ex_iterable.n_shards
+ def num_shards(self) -> int:
+ if self._distributed and self._ex_iterable.num_shards % self._distributed.world_size == 0:
+ return self._ex_iterable.num_shards // self._distributed.world_size
+ return self._ex_iterable.num_shards
+
+ @property
+ def n_shards(self) -> int: # backward compatibility
+ return self.num_shards
def _iter_pytorch(self):
ex_iterable = self._prepare_ex_iterable_for_iteration()
@@ -1930,24 +1951,28 @@ def _iter_pytorch(self):
import torch.utils.data
worker_info = torch.utils.data.get_worker_info()
- if self._is_main_process() and ex_iterable.n_shards < worker_info.num_workers:
+ if self._is_main_process() and ex_iterable.num_shards < worker_info.num_workers:
logger.warning(
- f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={ex_iterable.n_shards}). "
- f"Stopping {worker_info.num_workers - ex_iterable.n_shards} dataloader workers."
+ f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.num_shards={ex_iterable.num_shards}). "
+ f"Stopping {worker_info.num_workers - ex_iterable.num_shards} dataloader workers."
)
logger.info(
f"To parallelize data loading, we give each process some shards (or data sources) to process. "
- f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={ex_iterable.n_shards}. "
- f"To enable more parallelism, please split the dataset in more files than {ex_iterable.n_shards}."
+ f"Therefore it's unnecessary to have a number of workers greater than dataset.num_shards={ex_iterable.num_shards}. "
+ f"To enable more parallelism, please split the dataset in more files than {ex_iterable.num_shards}."
)
# split workload
_log_prefix = f"node#{self._distributed.rank} " if self._distributed else ""
- shards_indices = ex_iterable.split_shard_indices_by_worker(worker_info.id, worker_info.num_workers)
+ shards_indices = ex_iterable.split_shard_indices_by_worker(
+ num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
+ )
if shards_indices:
logger.debug(
- f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.n_shards} shards."
+ f"{_log_prefix}dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{ex_iterable.num_shards} shards."
+ )
+ ex_iterable = ex_iterable.shard_data_sources(
+ num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
)
- ex_iterable = ex_iterable.shard_data_sources(worker_id=worker_info.id, num_workers=worker_info.num_workers)
self._state_dict = ex_iterable._init_state_dict()
if self._starting_state_dict:
ex_iterable.load_state_dict(self._starting_state_dict)
@@ -1978,11 +2003,11 @@ def _iter_pytorch(self):
)
yield format_dict(example) if format_dict else example
logger.debug(
- f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.n_shards} shards."
+ f"{_log_prefix}dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{ex_iterable.num_shards} shards."
)
else:
logger.debug(
- f"{_log_prefix}dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({ex_iterable.n_shards}<{worker_info.num_workers})."
+ f"{_log_prefix}dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({ex_iterable.num_shards}<{worker_info.num_workers})."
)
def _is_main_process(self):
@@ -2012,14 +2037,14 @@ def _prepare_ex_iterable_for_iteration(
if self._distributed:
rank = self._distributed.rank
world_size = self._distributed.world_size
- if ex_iterable.n_shards % world_size == 0:
+ if ex_iterable.num_shards % world_size == 0:
if self._is_main_process():
- n_shards_per_node = ex_iterable.n_shards // world_size
- plural = "s" if n_shards_per_node > 1 else ""
+ num_shards_per_node = ex_iterable.num_shards // world_size
+ plural = "s" if num_shards_per_node > 1 else ""
logger.info(
- f"Assigning {n_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
+ f"Assigning {num_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
)
- ex_iterable = ex_iterable.shard_data_sources(rank, world_size)
+ ex_iterable = ex_iterable.shard_data_sources(num_shards=world_size, index=rank, contiguous=False)
else:
if self._is_main_process():
logger.info(
@@ -2028,7 +2053,7 @@ def _prepare_ex_iterable_for_iteration(
logger.info(
f"It is more optimized to distribute the dataset shards (or data sources) across nodes. "
f"You can do that by using a dataset with number of shards that is a factor of world_size={world_size}. "
- f"The current dataset has {ex_iterable.n_shards} which is not a factor of {world_size}"
+ f"The current dataset has {ex_iterable.num_shards} which is not a factor of {world_size}"
)
ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
@@ -2635,6 +2660,63 @@ def take(self, n: int) -> "IterableDataset":
token_per_repo_id=self._token_per_repo_id,
)
+ def shard(
+ self,
+ num_shards: int,
+ index: int,
+ contiguous: bool = True,
+ ) -> "Dataset":
+ """Return the `index`-nth shard from dataset split into `num_shards` pieces.
+
+ This shards deterministically. `dataset.shard(n, i)` splits the dataset into contiguous chunks,
+ so it can be easily concatenated back together after processing. If `dataset.num_shards % n == l`, then the
+ first `l` datasets each have `(dataset.num_shards // n) + 1` shards, and the remaining datasets have `(dataset.num_shards // n)` shards.
+ `datasets.concatenate_datasets([dset.shard(n, i) for i in range(n)])` returns a dataset with the same order as the original.
+ In particular, `dataset.shard(dataset.num_shards, i)` returns a dataset with 1 shard.
+
+ Note: n should be less than or equal to the number of shards in the dataset `dataset.num_shards`.
+
+ On the other hand, `dataset.shard(n, i, contiguous=False)` contains all the shards of the dataset whose index mod `n = i`.
+
+ Be sure to shard before using any randomizing operator (such as `shuffle`).
+ It is best if the shard operator is used early in the dataset pipeline.
+
+ Args:
+ num_shards (`int`):
+ How many shards to split the dataset into.
+ index (`int`):
+ Which shard to select and return.
+ contiguous: (`bool`, defaults to `True`):
+ Whether to select contiguous blocks of indices for shards.
+
+ Example:
+
+ ```py
+ >>> from datasets import load_dataset
+ >>> ds = load_dataset("amazon_polarity", split="train", streaming=True)
+ >>> ds
+ IterableDataset({
+ features: ['label', 'title', 'content'],
+ num_shards: 4
+ })
+ >>> ds.shard(num_shards=2, index=0)
+ IterableDataset({
+ features: ['label', 'title', 'content'],
+ num_shards: 2
+ })
+ ```
+ """
+ ex_iterable = self._ex_iterable.shard_data_sources(num_shards=num_shards, index=index, contiguous=contiguous)
+ return IterableDataset(
+ ex_iterable=ex_iterable,
+ info=self._info.copy(),
+ split=self._split,
+ formatting=self._formatting,
+ shuffling=copy.deepcopy(self._shuffling),
+ distributed=copy.deepcopy(self._distributed),
+ token_per_repo_id=self._token_per_repo_id,
+ )
+
@property
def column_names(self) -> Optional[List[str]]:
"""Names of the columns in the dataset.
@@ -3079,7 +3161,7 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
"""
Split an iterable dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
- If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+ If the dataset has a number of shards that is a factor of `world_size` (i.e. if `dataset.num_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most optimized.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
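For reference, a minimal standalone sketch of the shard-assignment arithmetic introduced above: `contiguous=True` gives each consumer a contiguous block of source indices (the new default for `IterableDataset.shard`), while `contiguous=False` keeps the round-robin assignment used for dataloader workers and distributed nodes. The function and variable names below are illustrative, not part of the library API.

```python
# Sketch of the index math in split_shard_indices_by_worker (names are illustrative).

def split_indices(total: int, num_shards: int, index: int, contiguous: bool = True) -> list:
    if contiguous:
        div, mod = divmod(total, num_shards)
        start = div * index + min(index, mod)
        end = start + div + (1 if index < mod else 0)
        return list(range(start, end))
    # round-robin: every num_shards-th source, starting at `index`
    return list(range(index, total, num_shards))


# 10 underlying data sources split across 3 consumers
print([split_indices(10, 3, i) for i in range(3)])
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
print([split_indices(10, 3, i, contiguous=False) for i in range(3)])
# [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```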
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 0faf2fd5cb5..ebdafafcd5f 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -1609,7 +1609,7 @@ def dataset_module_factory(
e.__cause__,
(
OfflineModeIsEnabled,
- requests.exceptions.ConnectTimeout,
+ requests.exceptions.Timeout,
requests.exceptions.ConnectionError,
),
):
@@ -1624,7 +1624,7 @@ def dataset_module_factory(
).sha
except (
OfflineModeIsEnabled,
- requests.exceptions.ConnectTimeout,
+ requests.exceptions.Timeout,
requests.exceptions.ConnectionError,
) as e:
raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e
diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py
index 6a23170db5e..94d806edf45 100644
--- a/src/datasets/packaged_modules/__init__.py
+++ b/src/datasets/packaged_modules/__init__.py
@@ -14,7 +14,9 @@
from .parquet import parquet
from .sql import sql
from .text import text
+from .videofolder import videofolder
from .webdataset import webdataset
+from .xml import xml
def _hash_python_lines(lines: List[str]) -> str:
@@ -40,7 +42,9 @@ def _hash_python_lines(lines: List[str]) -> str:
"text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())),
"imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())),
"audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())),
+ "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())),
"webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
+ "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
}
# get importable module names and hash for caching
@@ -69,12 +73,15 @@ def _hash_python_lines(lines: List[str]) -> str:
".arrow": ("arrow", {}),
".txt": ("text", {}),
".tar": ("webdataset", {}),
+ ".xml": ("xml", {}),
}
_EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext: ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
-_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder"}
+_EXTENSION_TO_MODULE.update({ext: ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder", "videofolder"}
# Used to filter data files based on extensions given a module name
_MODULE_TO_EXTENSIONS: Dict[str, List[str]] = {}
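With the new `_EXTENSION_TO_MODULE` entries, the packaged builder can be inferred directly from file extensions when no explicit builder name is given. A hedged usage sketch; both directory paths are placeholders:

```python
from datasets import load_dataset

# The builder is picked from the extensions found in the directory:
# ".mp4"/".mov"/... -> videofolder, ".xml" -> xml (paths below are hypothetical).
videos = load_dataset("path/to/folder_with_mp4_files", split="train")
documents = load_dataset("path/to/folder_with_xml_files", split="train")
```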
diff --git a/src/datasets/packaged_modules/spark/spark.py b/src/datasets/packaged_modules/spark/spark.py
index c21cb3dd981..00b97f99c9f 100644
--- a/src/datasets/packaged_modules/spark/spark.py
+++ b/src/datasets/packaged_modules/spark/spark.py
@@ -100,12 +100,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "SparkExamples
generator.shuffle(partition_order)
return SparkExamplesIterable(self.df, partition_order=partition_order)
- def shard_data_sources(self, worker_id: int, num_workers: int) -> "SparkExamplesIterable":
- partition_order = self.split_shard_indices_by_worker(worker_id, num_workers)
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SparkExamplesIterable":
+ partition_order = self.split_shard_indices_by_worker(num_shards=num_shards, index=index, contiguous=contiguous)
return SparkExamplesIterable(self.df, partition_order=partition_order)
@property
- def n_shards(self) -> int:
+ def num_shards(self) -> int:
return len(self.partition_order)
diff --git a/src/datasets/packaged_modules/videofolder/__init__.py b/src/datasets/packaged_modules/videofolder/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/videofolder/videofolder.py b/src/datasets/packaged_modules/videofolder/videofolder.py
new file mode 100644
index 00000000000..7ce5bcf5655
--- /dev/null
+++ b/src/datasets/packaged_modules/videofolder/videofolder.py
@@ -0,0 +1,36 @@
+from typing import List
+
+import datasets
+
+from ..folder_based_builder import folder_based_builder
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+class VideoFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
+ """BuilderConfig for ImageFolder."""
+
+ drop_labels: bool = None
+ drop_metadata: bool = None
+
+ def __post_init__(self):
+ super().__post_init__()
+
+
+class VideoFolder(folder_based_builder.FolderBasedBuilder):
+ BASE_FEATURE = datasets.Video
+ BASE_COLUMN_NAME = "video"
+ BUILDER_CONFIG_CLASS = VideoFolderConfig
+ EXTENSIONS: List[str] # definition at the bottom of the script
+
+
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+ ".mkv",
+ ".mp4",
+ ".avi",
+ ".mpeg",
+ ".mov",
+]
+VideoFolder.EXTENSIONS = VIDEO_EXTENSIONS
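Since `VideoFolder` reuses the folder-based builder, usage mirrors `imagefolder`. A sketch with an assumed directory layout (the paths and class names are hypothetical):

```python
from datasets import load_dataset

# Assumed layout:
#   videos/train/cats/clip_0.mp4
#   videos/train/dogs/clip_1.mp4
# Labels are inferred from the sub-directory names, as with imagefolder.
ds = load_dataset("videofolder", data_dir="videos", split="train")
print(ds.features)  # expected: a Video column plus a ClassLabel column
```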
diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py
index c04f3ba4639..0768437b36a 100644
--- a/src/datasets/packaged_modules/webdataset/webdataset.py
+++ b/src/datasets/packaged_modules/webdataset/webdataset.py
@@ -20,6 +20,7 @@ class WebDataset(datasets.GeneratorBasedBuilder):
DEFAULT_WRITER_BATCH_SIZE = 100
IMAGE_EXTENSIONS: List[str] # definition at the bottom of the script
AUDIO_EXTENSIONS: List[str] # definition at the bottom of the script
+ VIDEO_EXTENSIONS: List[str] # definition at the bottom of the script
DECODERS: Dict[str, Callable[[Any], Any]] # definition at the bottom of the script
NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5
@@ -97,6 +98,11 @@ def _split_generators(self, dl_manager):
extension = field_name.rsplit(".", 1)[-1]
if extension in self.AUDIO_EXTENSIONS:
features[field_name] = datasets.Audio()
+ # Set Video types
+ for field_name in first_examples[0]:
+ extension = field_name.rsplit(".", 1)[-1]
+ if extension in self.VIDEO_EXTENSIONS:
+ features[field_name] = datasets.Video()
self.info.features = features
return splits
@@ -259,6 +265,17 @@ def base_plus_ext(path):
WebDataset.AUDIO_EXTENSIONS = AUDIO_EXTENSIONS
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+ ".mkv",
+ ".mp4",
+ ".avi",
+ ".mpeg",
+ ".mov",
+]
+WebDataset.VIDEO_EXTENSIONS = VIDEO_EXTENSIONS
+
+
def text_loads(data: bytes):
return data.decode("utf-8")
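With `VIDEO_EXTENSIONS` registered, feature inference now types a field such as `mp4` as `datasets.Video()`. A sketch, assuming a hypothetical tar shard whose samples contain an `mp4` entry and a `json` entry:

```python
from datasets import load_dataset

# "train-000000.tar" is a placeholder; each WebDataset sample inside is expected
# to hold raw video bytes under the "mp4" key and metadata under "json".
ds = load_dataset("webdataset", data_files={"train": "train-000000.tar"}, split="train")
print(ds.features)  # the "mp4" column should come back typed as Video()
```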
diff --git a/src/datasets/packaged_modules/xml/__init__.py b/src/datasets/packaged_modules/xml/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/xml/xml.py b/src/datasets/packaged_modules/xml/xml.py
new file mode 100644
index 00000000000..d5009b4dd6a
--- /dev/null
+++ b/src/datasets/packaged_modules/xml/xml.py
@@ -0,0 +1,68 @@
+import itertools
+from dataclasses import dataclass
+from typing import Optional
+
+import pyarrow as pa
+
+import datasets
+from datasets.features.features import require_storage_cast
+from datasets.table import table_cast
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+@dataclass
+class XmlConfig(datasets.BuilderConfig):
+ """BuilderConfig for xml files."""
+
+ features: Optional[datasets.Features] = None
+ encoding: str = "utf-8"
+ encoding_errors: Optional[str] = None
+
+
+class Xml(datasets.ArrowBasedBuilder):
+ BUILDER_CONFIG_CLASS = XmlConfig
+
+ def _info(self):
+ return datasets.DatasetInfo(features=self.config.features)
+
+ def _split_generators(self, dl_manager):
+ """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]].
+
+ If str or List[str], then the dataset returns only the 'train' split.
+ If dict, then keys should be from the `datasets.Split` enum.
+ """
+ if not self.config.data_files:
+ raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
+ dl_manager.download_config.extract_on_the_fly = True
+ data_files = dl_manager.download_and_extract(self.config.data_files)
+ splits = []
+ for split_name, files in data_files.items():
+ if isinstance(files, str):
+ files = [files]
+ files = [dl_manager.iter_files(file) for file in files]
+ splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
+ return splits
+
+ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
+ if self.config.features is not None:
+ schema = self.config.features.arrow_schema
+ if all(not require_storage_cast(feature) for feature in self.config.features.values()):
+ # cheaper cast
+ pa_table = pa_table.cast(schema)
+ else:
+ # more expensive cast; allows str <-> int/float or str to Audio for example
+ pa_table = table_cast(pa_table, schema)
+ return pa_table
+ else:
+ return pa_table.cast(pa.schema({"xml": pa.string()}))
+
+ def _generate_tables(self, files):
+ pa_table_names = list(self.config.features) if self.config.features is not None else ["xml"]
+ for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
+ # open in text mode, by default translates universal newlines ("\n", "\r\n" and "\r") into "\n"
+ with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
+ xml = f.read()
+ pa_table = pa.Table.from_arrays([pa.array([xml])], names=pa_table_names)
+ yield file_idx, self._cast_table(pa_table)
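Each XML file is read whole and emitted as a single row, with the document stored as a string in an `xml` column. A minimal pyarrow sketch of the per-file table (the document text below is made up):

```python
import pyarrow as pa

# One input file -> one single-row table holding the full document as a string.
xml_document = "<root><item>hello</item><item>world</item></root>"  # stand-in for f.read()
pa_table = pa.Table.from_arrays([pa.array([xml_document])], names=["xml"])
print(pa_table.num_rows)      # 1
print(pa_table.column_names)  # ['xml']
```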
diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py
index e44b1ce12bc..ab2d11bc71a 100644
--- a/src/datasets/utils/file_utils.py
+++ b/src/datasets/utils/file_utils.py
@@ -828,8 +828,8 @@ def read_with_retries(*args, **kwargs):
except (
aiohttp.client_exceptions.ClientError,
asyncio.TimeoutError,
- requests.exceptions.ConnectTimeout,
requests.exceptions.ConnectionError,
+ requests.exceptions.Timeout,
) as err:
disconnect_err = err
logger.warning(
diff --git a/src/datasets/utils/tf_utils.py b/src/datasets/utils/tf_utils.py
index 2de35a943e7..154f4d61072 100644
--- a/src/datasets/utils/tf_utils.py
+++ b/src/datasets/utils/tf_utils.py
@@ -278,7 +278,7 @@ def __init__(
self.cols_to_retain = cols_to_retain
self.collate_fn = collate_fn
self.collate_fn_args = collate_fn_args
- self.string_columns = [col for col, dtype in columns_to_np_types.items() if dtype in (np.unicode_, np.str_)]
+ self.string_columns = [col for col, dtype in columns_to_np_types.items() if dtype is np.str_]
# Strings will be converted to arrays of single unicode chars, so that we can have a constant itemsize
self.columns_to_np_types = {
col: dtype if col not in self.string_columns else np.dtype("U1")
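The tuple membership test is collapsed to a single identity check because `np.unicode_` was only ever an alias of `np.str_` on NumPy 1.x and was removed in NumPy 2.0. A small illustration; the alias line is guarded since it only exists on older NumPy:

```python
import numpy as np

dtype = np.str_
print(dtype is np.str_)  # True on any NumPy version

# On NumPy 1.x the old alias pointed at the same object (it is gone in 2.x):
if hasattr(np, "unicode_"):
    print(np.unicode_ is np.str_)  # True
```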
diff --git a/tests/features/data/test_video_66x50.mov b/tests/features/data/test_video_66x50.mov
new file mode 100644
index 00000000000..a55dcaa8f7b
Binary files /dev/null and b/tests/features/data/test_video_66x50.mov differ
diff --git a/tests/features/test_video.py b/tests/features/test_video.py
new file mode 100644
index 00000000000..f4c9a8d830b
--- /dev/null
+++ b/tests/features/test_video.py
@@ -0,0 +1,92 @@
+import pytest
+
+from datasets import Dataset, Features, Video
+
+from ..utils import require_decord
+
+
+@require_decord
+@pytest.mark.parametrize(
+ "build_example",
+ [
+ lambda video_path: video_path,
+ lambda video_path: open(video_path, "rb").read(),
+ lambda video_path: {"path": video_path},
+ lambda video_path: {"path": video_path, "bytes": None},
+ lambda video_path: {"path": video_path, "bytes": open(video_path, "rb").read()},
+ lambda video_path: {"path": None, "bytes": open(video_path, "rb").read()},
+ lambda video_path: {"bytes": open(video_path, "rb").read()},
+ ],
+)
+def test_video_feature_encode_example(shared_datadir, build_example):
+ from decord import VideoReader
+
+ video_path = str(shared_datadir / "test_video_66x50.mov")
+ video = Video()
+ encoded_example = video.encode_example(build_example(video_path))
+ assert isinstance(encoded_example, dict)
+ assert encoded_example.keys() == {"bytes", "path"}
+ assert encoded_example["bytes"] is not None or encoded_example["path"] is not None
+ decoded_example = video.decode_example(encoded_example)
+ assert isinstance(decoded_example, VideoReader)
+
+
+@require_decord
+def test_dataset_with_video_feature(shared_datadir):
+ from decord import VideoReader
+ from decord.ndarray import NDArray
+
+ video_path = str(shared_datadir / "test_video_66x50.mov")
+ data = {"video": [video_path]}
+ features = Features({"video": Video()})
+ dset = Dataset.from_dict(data, features=features)
+ item = dset[0]
+ assert item.keys() == {"video"}
+ assert isinstance(item["video"], VideoReader)
+ assert item["video"][0].shape == (50, 66, 3)
+ assert isinstance(item["video"][0], NDArray)
+ batch = dset[:1]
+ assert len(batch) == 1
+ assert batch.keys() == {"video"}
+ assert isinstance(batch["video"], list) and all(isinstance(item, VideoReader) for item in batch["video"])
+ assert batch["video"][0][0].shape == (50, 66, 3)
+ assert isinstance(batch["video"][0][0], NDArray)
+ column = dset["video"]
+ assert len(column) == 1
+ assert isinstance(column, list) and all(isinstance(item, VideoReader) for item in column)
+ assert column[0][0].shape == (50, 66, 3)
+ assert isinstance(column[0][0], NDArray)
+
+ # from bytes
+ with open(video_path, "rb") as f:
+ data = {"video": [f.read()]}
+ dset = Dataset.from_dict(data, features=features)
+ item = dset[0]
+ assert item.keys() == {"video"}
+ assert isinstance(item["video"], VideoReader)
+ assert item["video"][0].shape == (50, 66, 3)
+ assert isinstance(item["video"][0], NDArray)
+
+
+@require_decord
+def test_dataset_with_video_map_and_formatted(shared_datadir):
+ import numpy as np
+ from decord import VideoReader
+
+ video_path = str(shared_datadir / "test_video_66x50.mov")
+ data = {"video": [video_path]}
+ features = Features({"video": Video()})
+ dset = Dataset.from_dict(data, features=features)
+ dset = dset.map(lambda x: x).with_format("numpy")
+ example = dset[0]
+ assert isinstance(example["video"], VideoReader)
+ assert isinstance(example["video"][0], np.ndarray)
+
+ # from bytes
+ with open(video_path, "rb") as f:
+ data = {"video": [f.read()]}
+ dset = Dataset.from_dict(data, features=features)
+ dset = dset.map(lambda x: x).with_format("numpy")
+ example = dset[0]
+ assert isinstance(example["video"], VideoReader)
+ assert isinstance(example["video"][0], np.ndarray)
diff --git a/tests/packaged_modules/test_spark.py b/tests/packaged_modules/test_spark.py
index c91bdd571ea..89406c6a8dd 100644
--- a/tests/packaged_modules/test_spark.py
+++ b/tests/packaged_modules/test_spark.py
@@ -72,7 +72,7 @@ def test_spark_examples_iterable():
spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("pyspark").getOrCreate()
df = spark.range(10).repartition(1)
it = SparkExamplesIterable(df)
- assert it.n_shards == 1
+ assert it.num_shards == 1
for i, (row_id, row_dict) in enumerate(it):
assert row_id == f"0_{i}"
assert row_dict == {"id": i}
@@ -89,7 +89,7 @@ def test_spark_examples_iterable_shuffle():
expected_row_ids_and_row_dicts = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [2, 1, 0])
shuffled_it = SparkExamplesIterable(df).shuffle_data_sources(generator_mock)
- assert shuffled_it.n_shards == 3
+ assert shuffled_it.num_shards == 3
for i, (row_id, row_dict) in enumerate(shuffled_it):
expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts[i]
assert row_id == expected_row_id
@@ -103,8 +103,8 @@ def test_spark_examples_iterable_shard():
df = spark.range(20).repartition(4)
# Partitions 0 and 2
- shard_it_1 = SparkExamplesIterable(df).shard_data_sources(worker_id=0, num_workers=2)
- assert shard_it_1.n_shards == 2
+ shard_it_1 = SparkExamplesIterable(df).shard_data_sources(index=0, num_shards=2, contiguous=False)
+ assert shard_it_1.num_shards == 2
expected_row_ids_and_row_dicts_1 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [0, 2])
for i, (row_id, row_dict) in enumerate(shard_it_1):
expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_1[i]
@@ -112,8 +112,8 @@ def test_spark_examples_iterable_shard():
assert row_dict == expected_row_dict
# Partitions 1 and 3
- shard_it_2 = SparkExamplesIterable(df).shard_data_sources(worker_id=1, num_workers=2)
- assert shard_it_2.n_shards == 2
+ shard_it_2 = SparkExamplesIterable(df).shard_data_sources(index=1, num_shards=2, contiguous=False)
+ assert shard_it_2.num_shards == 2
expected_row_ids_and_row_dicts_2 = _get_expected_row_ids_and_row_dicts_for_partition_order(df, [1, 3])
for i, (row_id, row_dict) in enumerate(shard_it_2):
expected_row_id, expected_row_dict = expected_row_ids_and_row_dicts_2[i]
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index ffa048644e2..1e08862031b 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -2630,10 +2630,12 @@ def test_shard(self, in_memory):
tmp_file = os.path.join(tmp_dir, "test.arrow")
with dset.select(range(10), indices_cache_file_name=tmp_file) as dset:
self.assertEqual(len(dset), 10)
- # Shard
+ # Shard non-contiguous
tmp_file_1 = os.path.join(tmp_dir, "test_1.arrow")
fingerprint = dset._fingerprint
- with dset.shard(num_shards=8, index=1, indices_cache_file_name=tmp_file_1) as dset_sharded:
+ with dset.shard(
+ num_shards=8, index=1, contiguous=False, indices_cache_file_name=tmp_file_1
+ ) as dset_sharded:
self.assertEqual(2, len(dset_sharded))
self.assertEqual(["my_name-train_1", "my_name-train_9"], dset_sharded["filename"])
self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
@@ -4268,7 +4270,7 @@ def test_dataset_to_iterable_dataset(dataset: Dataset):
assert isinstance(iterable_dataset, IterableDataset)
assert list(iterable_dataset) == list(dataset)
assert iterable_dataset.features == dataset.features
- assert iterable_dataset.n_shards == 3
+ assert iterable_dataset.num_shards == 3
with pytest.raises(ValueError):
dataset.to_iterable_dataset(num_shards=len(dataset) + 1)
with pytest.raises(NotImplementedError):
diff --git a/tests/test_distributed.py b/tests/test_distributed.py
index b8e0f56b180..65d2130f753 100644
--- a/tests/test_distributed.py
+++ b/tests/test_distributed.py
@@ -46,11 +46,11 @@ def gen(shards):
gen_kwargs = {"shards": [f"shard_{shard_idx}.txt" for shard_idx in range(num_shards)]}
full_ds = IterableDataset.from_generator(gen, gen_kwargs=gen_kwargs)
full_size = len(list(full_ds))
- assert full_ds.n_shards == world_size * shards_per_node
+ assert full_ds.num_shards == world_size * shards_per_node
datasets_per_rank = [
split_dataset_by_node(full_ds, rank=rank, world_size=world_size) for rank in range(world_size)
]
- assert [ds.n_shards for ds in datasets_per_rank] == [shards_per_node] * world_size
+ assert [ds.num_shards for ds in datasets_per_rank] == [shards_per_node] * world_size
assert sum(len(list(ds)) for ds in datasets_per_rank) == full_size
assert len({tuple(x.values()) for ds in datasets_per_rank for x in ds}) == full_size
diff --git a/tests/test_hub.py b/tests/test_hub.py
index 9485fe83a71..13c496e0f6f 100644
--- a/tests/test_hub.py
+++ b/tests/test_hub.py
@@ -5,9 +5,10 @@
import pytest
from huggingface_hub import CommitOperationAdd, CommitOperationDelete
+from packaging import version
import datasets
-from datasets.config import METADATA_CONFIGS_FIELD
+from datasets.config import METADATA_CONFIGS_FIELD, PYARROW_VERSION
from datasets.hub import convert_to_parquet, delete_from_hub
from datasets.utils.hub import hf_dataset_url
@@ -83,7 +84,7 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_
- name: train
num_bytes: 55
num_examples: 5
- download_size: 790
+ download_size: 726
dataset_size: 55
{METADATA_CONFIGS_FIELD}:
- config_name: first
@@ -104,7 +105,7 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_
- name: train
num_bytes: 60
num_examples: 5
- download_size: 798
+ download_size: 732
dataset_size: 60
{METADATA_CONFIGS_FIELD}:
- config_name: second
@@ -114,6 +115,9 @@ def test_convert_to_parquet(temporary_repo, hf_api, hf_token, ci_hub_config, ci_
---
"""),
]
+ if PYARROW_VERSION < version.parse("18.0.0"):
+ expected_readmes[0] = expected_readmes[0].replace("download_size: 726", "download_size: 790")
+ expected_readmes[1] = expected_readmes[1].replace("download_size: 732", "download_size: 798")
for call_args, expected_commit_message, expected_create_pr, expected_readme, expected_parquet_path_in_repo in zip(
mock_create_commit.call_args_list,
["Convert dataset to Parquet", "Add 'second' config data files"],
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 232652f1fa3..44d3d24fa23 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -1277,7 +1277,7 @@ def gen(shard_names):
shard_names = [f"data{shard_idx}.txt" for shard_idx in range(4)]
dataset = IterableDataset.from_generator(gen, gen_kwargs={"shard_names": shard_names})
assert isinstance(dataset, IterableDataset)
- assert dataset.n_shards == len(shard_names)
+ assert dataset.num_shards == len(shard_names)
@require_numpy1_on_windows
@@ -1392,11 +1392,11 @@ def test_iterable_dataset_torch_dataloader_parallel():
@require_torch
@pytest.mark.filterwarnings("ignore:This DataLoader will create:UserWarning")
-@pytest.mark.parametrize("n_shards, num_workers", [(2, 1), (2, 2), (3, 2), (2, 3)])
-def test_sharded_iterable_dataset_torch_dataloader_parallel(n_shards, num_workers):
+@pytest.mark.parametrize("num_shards, num_workers", [(2, 1), (2, 2), (3, 2), (2, 3)])
+def test_sharded_iterable_dataset_torch_dataloader_parallel(num_shards, num_workers):
from torch.utils.data import DataLoader
- ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(n_shards)]})
+ ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}.txt" for i in range(num_shards)]})
dataset = IterableDataset(ex_iterable)
dataloader = DataLoader(dataset, batch_size=None, num_workers=num_workers)
result = list(dataloader)
@@ -1681,13 +1681,36 @@ def test_iterable_dataset_take(dataset: IterableDataset, n):
assert list(take_dataset) == list(dataset)[:n]
+def test_iterable_dataset_shard():
+ num_examples = 20
+ num_shards = 5
+ dataset = Dataset.from_dict({"a": range(num_examples)}).to_iterable_dataset(num_shards=num_shards)
+ assert sum(dataset.shard(num_shards, i).num_shards for i in range(num_shards)) == dataset.num_shards
+ assert list(concatenate_datasets([dataset.shard(num_shards, i) for i in range(num_shards)])) == list(dataset)
+ num_shards = 2
+ assert sum(dataset.shard(num_shards, i).num_shards for i in range(num_shards)) == dataset.num_shards
+ assert list(concatenate_datasets([dataset.shard(num_shards, i) for i in range(num_shards)])) == list(dataset)
+ assert (
+ sum(dataset.shard(num_shards, i, contiguous=False).num_shards for i in range(num_shards)) == dataset.num_shards
+ )
+ assert list(
+ concatenate_datasets([dataset.shard(num_shards, i, contiguous=False) for i in range(num_shards)])
+ ) != list(dataset)
+ assert sorted(
+ concatenate_datasets([dataset.shard(num_shards, i, contiguous=False) for i in range(num_shards)]),
+ key=lambda x: x["a"],
+ ) == list(dataset)
+
+
@pytest.mark.parametrize("method", ["skip", "take"])
@pytest.mark.parametrize("after_shuffle", [False, True])
@pytest.mark.parametrize("count", [2, 5, 11])
def test_iterable_dataset_skip_or_take_after_shuffle(method, after_shuffle, count):
seed = 42
- n, n_shards = 3, 10
- ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]})
+ n, num_shards = 3, 10
+ ex_iterable = ExamplesIterable(
+ generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(num_shards)]}
+ )
dataset = IterableDataset(ex_iterable)
shuffled_dataset = dataset
if after_shuffle:
@@ -1714,9 +1737,11 @@ def test_iterable_dataset_skip_or_take_after_shuffle(method, after_shuffle, coun
@pytest.mark.parametrize("after_split_by_node", [False, True])
@pytest.mark.parametrize("count", [2, 5, 11])
def test_iterable_dataset_skip_or_take_after_split_by_node(method, after_split_by_node, count):
- n, n_shards = 3, 10
+ n, num_shards = 3, 10
rank, world_size = 1, 2
- ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]})
+ ex_iterable = ExamplesIterable(
+ generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(num_shards)]}
+ )
dataset = IterableDataset(ex_iterable)
distributed_dataset = dataset
true_distributed_dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
@@ -2114,17 +2139,17 @@ def add_one_numpy(example):
assert isinstance(next(dataset.iter(batch_size=3))["id"], list)
-@pytest.mark.parametrize("n_shards1, n_shards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)])
-def test_interleave_dataset_with_sharding(n_shards1, n_shards2, num_workers):
+@pytest.mark.parametrize("num_shards1, num_shards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)])
+def test_interleave_dataset_with_sharding(num_shards1, num_shards2, num_workers):
from torch.utils.data import DataLoader
- ex_iterable1 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-1.txt" for i in range(n_shards1)]})
+ ex_iterable1 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-1.txt" for i in range(num_shards1)]})
dataset1 = IterableDataset(ex_iterable1).with_format("torch")
- ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(n_shards2)]})
+ ex_iterable2 = ExamplesIterable(generate_examples_fn, {"filepaths": [f"{i}-2.txt" for i in range(num_shards2)]})
dataset2 = IterableDataset(ex_iterable2).with_format("torch")
dataset_merged = interleave_datasets([dataset1, dataset2], stopping_strategy="first_exhausted")
- assert dataset_merged.n_shards == min(n_shards1, n_shards2)
+ assert dataset_merged.num_shards == min(num_shards1, num_shards2)
dataloader = DataLoader(dataset_merged, batch_size=None, num_workers=num_workers)
result = list(dataloader)
expected_length = 2 * min(
diff --git a/tests/test_offline_util.py b/tests/test_offline_util.py
index 7bea143df4b..22a372205ad 100644
--- a/tests/test_offline_util.py
+++ b/tests/test_offline_util.py
@@ -16,11 +16,11 @@ def test_offline_with_timeout():
with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT):
with pytest.raises(RequestWouldHangIndefinitelyError):
requests.request("GET", "https://huggingface.co")
- with pytest.raises(requests.exceptions.ConnectTimeout):
+ with pytest.raises(requests.exceptions.Timeout):
requests.request("GET", "https://huggingface.co", timeout=1.0)
# old versions of `huggingface_hub` don't have timeouts by default and don't allow to set timeouts in HfFileSystem
if version.parse(huggingface_hub.__version__) >= version.parse("0.23.0"):
- with pytest.raises(requests.exceptions.ConnectTimeout), NamedTemporaryFile() as temp_file:
+ with pytest.raises(requests.exceptions.Timeout), NamedTemporaryFile() as temp_file:
fsspec_get("hf://dummy", temp_file=temp_file)
diff --git a/tests/utils.py b/tests/utils.py
index e19740a2a12..08497e1eae7 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -178,6 +178,18 @@ def require_pil(test_case):
return test_case
+def require_decord(test_case):
+ """
+ Decorator marking a test that requires decord.
+
+ These tests are skipped when decord isn't installed.
+
+ """
+ if not config.DECORD_AVAILABLE:
+ test_case = unittest.skip("test requires decord")(test_case)
+ return test_case
+
+
def require_transformers(test_case):
"""
Decorator marking a test that requires transformers.