diff --git a/opacus/data_loader.py b/opacus/data_loader.py index 0f45bbfb..5350e725 100644 --- a/opacus/data_loader.py +++ b/opacus/data_loader.py @@ -22,7 +22,7 @@ ) from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler from torch.utils.data._utils.collate import default_collate -from torch.utils.data.dataloader import _collate_fn_t, _worker_init_fn_t +from torch.utils.data.dataloader import _collate_fn_t logger = logging.getLogger(__name__) @@ -114,17 +114,11 @@ def __init__( dataset: Dataset, *, sample_rate: float, - num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, - pin_memory: bool = False, drop_last: bool = False, - timeout: float = 0, - worker_init_fn: Optional[_worker_init_fn_t] = None, - multiprocessing_context=None, generator=None, - prefetch_factor: int = 2, - persistent_workers: bool = False, distributed: bool = False, + **kwargs, ): """ @@ -175,19 +169,13 @@ def __init__( super().__init__( dataset=dataset, batch_sampler=batch_sampler, - num_workers=num_workers, collate_fn=wrap_collate_with_empty( collate_fn=collate_fn, sample_empty_shapes=sample_empty_shapes, dtypes=dtypes, ), - pin_memory=pin_memory, - timeout=timeout, - worker_init_fn=worker_init_fn, - multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers, + **kwargs, ) @classmethod