Drop incomplete batches for Ray and Pandas to prevent Batchnorm computation errors by arnavgarg1 · Pull Request #2778 · ludwig-ai/ludwig · GitHub

Drop incomplete batches for Ray and Pandas to prevent Batchnorm computation errors #2778

Merged · 8 commits · Dec 13, 2022
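
For context, a minimal PyTorch sketch (not part of this PR) of the failure the change guards against: in training mode, BatchNorm cannot compute batch statistics from a single sample, so a trailing one-row batch raises an error.

import torch
import torch.nn as nn

# Minimal illustration (outside Ludwig) of the BatchNorm error this PR prevents:
# in training mode, BatchNorm needs more than one value per channel to compute
# batch statistics, so a batch of size 1 raises a ValueError.
bn = nn.BatchNorm1d(num_features=4)
bn.train()
print(bn(torch.randn(2, 4)).shape)  # a 2-row batch works: torch.Size([2, 4])
try:
    bn(torch.randn(1, 4))  # a 1-row batch fails in training mode
except ValueError as err:
    print(err)  # "Expected more than 1 value per channel when training, ..."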
19 changes: 18 additions & 1 deletion ludwig/data/batcher/random_access.py
@@ -13,11 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import math

from ludwig.api_annotations import DeveloperAPI
from ludwig.data.batcher.base import Batcher

logger = logging.getLogger(__name__)


@DeveloperAPI
class RandomAccessBatcher(Batcher):
def __init__(self, dataset, sampler, batch_size=128, ignore_last=False):
# store our dataset as well
@@ -52,7 +57,19 @@ def next_batch(self):
return sub_batch

def last_batch(self):
return self.index >= self.total_size or (self.ignore_last and self.index + self.batch_size >= self.total_size)
# If our current index into the dataset has reached or passed the dataset size,
# we've finished the epoch and can indicate that this is the last batch.
if self.index >= self.total_size:
return True
# Only drop the last batch if at least one step has already been taken. This avoids the case where
# batch size > total size and no steps have been done, e.g. batch size = 128 but the dataset only has 100 rows.
elif self.ignore_last and self.step:
# index += batch_size after each batch, so if our current index into the dataset is 1 less than the total
# dataset size, the last batch would only have 1 row. Drop it if this happens.
if self.index - self.total_size == -1:
logger.info("Last batch in epoch only has 1 sample and will be dropped.")
return True
return False

def set_epoch(self, epoch, batch_size):
self.batch_size = batch_size
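To make the new condition concrete, below is a hedged side-by-side of the old and new last_batch() logic with ignore_last=True. The helper functions and dataset sizes are illustrative only, not Ludwig code.

# Hedged comparison of the old and new last_batch() conditions with ignore_last=True,
# using hypothetical dataset sizes. The old check dropped any incomplete trailing batch
# and could drop the only batch when batch_size >= dataset size; the new check only
# drops a trailing batch of exactly one row, which is what trips BatchNorm.
def old_last_batch(index, step, total_size, batch_size):
    return index >= total_size or index + batch_size >= total_size

def new_last_batch(index, step, total_size, batch_size):
    if index >= total_size:
        return True
    if step and index - total_size == -1:
        return True
    return False

def steps_per_epoch(last_batch_fn, total_size, batch_size):
    index = step = 0
    while not last_batch_fn(index, step, total_size, batch_size):
        index += batch_size  # mirrors next_batch() advancing the index
        step += 1
    return step

for total_size in (100, 150, 201):  # batch_size=100: exact fit, 50-row tail, 1-row tail
    print(total_size,
          steps_per_epoch(old_last_batch, total_size, 100),
          steps_per_epoch(new_last_batch, total_size, 100))
# 100 -> old 0 steps (whole epoch dropped), new 1 step
# 150 -> old 1 step (50-row tail dropped), new 2 steps (tail kept)
# 201 -> old 2 steps, new 2 steps (the 1-row tail is dropped by both)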
26 changes: 23 additions & 3 deletions ludwig/data/dataset/ray.py
@@ -30,6 +30,7 @@
from ray.data.dataset_pipeline import DatasetPipeline
from ray.data.extensions import TensorDtype

from ludwig.api_annotations import DeveloperAPI
from ludwig.backend.base import Backend
from ludwig.constants import BINARY, CATEGORY, NAME, NUMBER, TYPE
from ludwig.data.batcher.base import Batcher
@@ -62,12 +63,14 @@ def cast_as_tensor_dtype(series: Series) -> Series:
return series.astype(TensorDtype())


@DeveloperAPI
@default_retry()
def read_remote_parquet(path: str):
fs, path = get_fs_and_path(path)
return read_parquet(path, filesystem=PyFileSystem(FSSpecHandler(fs)))


@DeveloperAPI
class RayDataset(Dataset):
"""Wrapper around ray.data.Dataset.

@@ -146,13 +149,21 @@ def pipeline(
return pipe

@contextlib.contextmanager
def initialize_batcher(self, batch_size=128, should_shuffle=True, seed=0, ignore_last=False, horovod=None):
def initialize_batcher(
self,
batch_size=128,
should_shuffle=True,
seed=0,
ignore_last=False,
horovod=None,
):
yield RayDatasetBatcher(
self.ds.repeat().iter_datasets(),
self.features,
self.training_set_metadata,
batch_size,
self.size,
ignore_last,
)

def __len__(self):
@@ -176,6 +187,7 @@ def to_df(self):
return self.df_engine.from_ray_dataset(self.ds)


@DeveloperAPI
class RayDatasetManager(DatasetManager):
def __init__(self, backend):
self.backend = backend
@@ -215,6 +227,7 @@ def data_format(self):
return "parquet"


@DeveloperAPI
class RayDatasetShard(Dataset):
def __init__(
self,
@@ -235,6 +248,7 @@ def initialize_batcher(self, batch_size=128, should_shuffle=True, seed=0, ignore
self.training_set_metadata,
batch_size,
self.size,
ignore_last,
)

@lru_cache(1)
@@ -247,6 +261,7 @@ def size(self):
return len(self)


@DeveloperAPI
class RayDatasetBatcher(Batcher):
def __init__(
self,
@@ -255,14 +270,17 @@ def __init__(
training_set_metadata: Dict[str, Any],
batch_size: int,
samples_per_epoch: int,
ignore_last: bool = False,
):
self.dataset_epoch_iterator = dataset_epoch_iterator
self.batch_size = batch_size
self.samples_per_epoch = samples_per_epoch
self.training_set_metadata = training_set_metadata
self.ignore_last = ignore_last

self.features = features
self.columns = list(features.keys())
self._sample_feature_name = self.columns[0]
self.reshape_map = {
proc_column: training_set_metadata[feature[NAME]].get("reshape")
for proc_column, feature in features.items()
@@ -325,6 +343,10 @@ def _fetch_next_batch(self):
self._last_batch = False
try:
self._next_batch = next(self.dataset_batch_iter)
# If ignore_last is set and the batch has only one row, skip the batch
# to prevent batchnorm / dropout related Torch errors
if self.ignore_last and len(self._next_batch[self._sample_feature_name]) == 1:
raise StopIteration
except StopIteration:
self._last_batch = True

@@ -372,9 +394,7 @@ def sync_read():

def _create_async_reader(self, pipeline: DatasetPipeline):
q = queue.Queue(maxsize=100)

batch_size = self.batch_size

to_tensors = self._to_tensors_fn()

def producer():
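The RayDatasetBatcher change above turns a one-row batch into a StopIteration so the epoch simply ends early. A standalone sketch of the same pattern follows; the drop_one_row_tail helper and toy data are hypothetical, not Ludwig or Ray APIs.

from typing import Any, Dict, Iterator

# Hypothetical helper mirroring the skip in _fetch_next_batch: when ignore_last is set
# and the fetched batch holds exactly one row, end the epoch instead of yielding the
# degenerate batch to the training loop.
def drop_one_row_tail(batches: Iterator[Dict[str, Any]], sample_column: str,
                      ignore_last: bool = True) -> Iterator[Dict[str, Any]]:
    for batch in batches:
        if ignore_last and len(batch[sample_column]) == 1:
            return  # same effect as raising StopIteration inside the batcher
        yield batch

# Toy usage with plain dicts standing in for the pandas batches Ray yields:
batches = iter([{"feat": [1, 2, 3]}, {"feat": [4, 5]}, {"feat": [6]}])
print(list(drop_one_row_tail(batches, "feat")))  # the final one-row batch is dropped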
4 changes: 2 additions & 2 deletions ludwig/models/predictor.py
@@ -76,7 +76,7 @@ def batch_predict(self, dataset: Dataset, dataset_name: str = None, collect_logi
self.model.eval() # set model to eval mode

with torch.no_grad():
with dataset.initialize_batcher(self._batch_size, should_shuffle=False, horovod=self._horovod) as batcher:
with dataset.initialize_batcher(self._batch_size, should_shuffle=False) as batcher:

progress_bar_config = {
"desc": "Prediction" if dataset_name is None else f"Prediction {dataset_name: <5.5}",
@@ -234,7 +234,7 @@ def batch_collect_activations(self, layer_names, dataset, bucketing_field=None):
self.model.eval() # set model to eval mode

with torch.no_grad():
with dataset.initialize_batcher(self._batch_size, should_shuffle=False) as batcher:
with dataset.initialize_batcher(self._batch_size, should_shuffle=False, horovod=self._horovod) as batcher:
progress_bar_config = {
"desc": "Collecting Tensors",
"total": batcher.steps_per_epoch,
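The predictor keeps ignore_last at its default of False, which is safe because the model is switched to eval mode first. A small sketch (again outside Ludwig) of why a one-row batch is harmless during prediction:

import torch
import torch.nn as nn

# In eval mode BatchNorm uses its running statistics instead of per-batch ones,
# so a single-row batch is scored without error and no prediction rows are dropped.
bn = nn.BatchNorm1d(num_features=4)
bn.eval()
print(bn(torch.randn(1, 4)).shape)  # torch.Size([1, 4]) -- no error in eval mode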
8 changes: 6 additions & 2 deletions ludwig/trainers/trainer.py
@@ -388,7 +388,10 @@ def get_optimal_lr(losses, learning_rates, skip_begin: int = 10, skip_end: int =

self.model.train() # Sets model training mode.
with training_set.initialize_batcher(
batch_size=self.batch_size, should_shuffle=self.should_shuffle, horovod=self.horovod
batch_size=self.batch_size,
should_shuffle=self.should_shuffle,
horovod=self.horovod,
ignore_last=True,
) as batcher:
step_count = 0
while epoch < self.epochs and step_count < total_training_steps and not diverging:
@@ -796,6 +799,7 @@ def train(self, training_set, validation_set=None, test_set=None, save_path="mod
should_shuffle=self.should_shuffle,
seed=self.random_seed,
horovod=self.horovod,
ignore_last=True,
) as batcher:
# ================ Training Loop ================
self.total_steps = get_total_steps(self.epochs, batcher.steps_per_epoch, self.train_steps)
@@ -1040,7 +1044,7 @@
def train_online(self, dataset):
self.model.train() # Sets model training mode.
with dataset.initialize_batcher(
batch_size=self.batch_size, should_shuffle=self.should_shuffle, horovod=self.horovod
batch_size=self.batch_size, should_shuffle=self.should_shuffle, horovod=self.horovod, ignore_last=True
) as batcher:

# training step loop