diff --git a/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py b/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py index 2598aa2..b0fb041 100644 --- a/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py +++ b/src/jax_loop_utils/metric_writers/_audio_video/audio_video.py @@ -25,21 +25,44 @@ def encode_video(video_array: Array, destination: io.IOBase): Args: video_array: array to encode. Must have shape (T, H, W, 1) or (T, H, W, 3), where T is the number of frames, H is the height, W is the width, and the last - dimension is the number of channels. Must have dtype uint8. + dimension is the number of channels. + Must be ints in [0, 255] or floats in [0, 1]. destination: Destination to write the encoded video. """ video_array = np.array(video_array) + if video_array.ndim != 4 or video_array.shape[-1] not in (1, 3): + raise ValueError( + "Expected an array with shape (T, H, W, 1) or (T, H, W, 3)." + f"Got shape {video_array.shape} with dtype {video_array.dtype}." + ) + if ( - video_array.dtype != np.uint8 - or video_array.ndim != 4 - or video_array.shape[-1] not in (1, 3) + np.issubdtype(video_array.dtype, np.floating) + and np.all(0 <= video_array) + and np.all(video_array <= 1.0) ): + video_array = (video_array * 255).astype(np.uint8) + elif ( + np.issubdtype(video_array.dtype, np.integer) + and np.all(0 <= video_array) + and np.all(video_array <= 255) + ): + video_array = video_array.astype(np.uint8) + else: raise ValueError( - "Expected a uint8 array with shape (T, H, W, 1) or (T, H, W, 3)." - f"Got shape {video_array.shape} with dtype {video_array.dtype}." + f"Expected video_array to be floats in [0, 1] or ints in [0, 255], got {video_array.dtype}" ) T, H, W, C = video_array.shape + # Pad height and width to even numbers if necessary + pad_h = H % 2 + pad_w = W % 2 + if pad_h or pad_w: + padding = [(0, 0), (0, pad_h), (0, pad_w), (0, 0)] + video_array = np.pad(video_array, padding, mode="constant") + H += pad_h + W += pad_w + is_grayscale = C == 1 if is_grayscale: video_array = np.squeeze(video_array, axis=-1) diff --git a/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py b/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py index 3609dd8..dfdedd6 100644 --- a/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py +++ b/src/jax_loop_utils/metric_writers/_audio_video/audio_video_test.py @@ -19,17 +19,19 @@ class VideoTest(absltest.TestCase): def test_encode_video_invalid_args(self): """Test that encode_video raises appropriate errors for invalid inputs.""" invalid_shape = np.zeros((10, 20, 30, 4), dtype=np.uint8) - with self.assertRaisesRegex(ValueError, "Expected a uint8 array with shape"): + with self.assertRaisesRegex(ValueError, r"Expected an array with shape"): encode_video(invalid_shape, io.BytesIO()) - invalid_dtype = np.zeros((10, 20, 30, 3), dtype=np.float32) - with self.assertRaisesRegex(ValueError, "Expected a uint8 array with shape"): + invalid_dtype = 2 * np.ones((10, 20, 30, 3), dtype=np.float32) + with self.assertRaisesRegex( + ValueError, r"Expected video_array to be floats in \[0, 1\]" + ): encode_video(invalid_dtype, io.BytesIO()) def test_encode_video_success(self): """Test successful video encoding.""" # Create a simple test video - red square moving diagonally - T, H, W = 20, 64, 64 + T, H, W = 20, 63, 63 # test non-even dimensions video = np.zeros((T, H, W, 3), dtype=np.uint8) for t in range(T): pos = t * 5 # Move 5 pixels each frame @@ -41,8 +43,8 @@ def test_encode_video_success(self): output.seek(0) with av.open(output, mode="r", format=CONTAINER_FORMAT) as container: stream = container.streams.video[0] - self.assertEqual(stream.codec_context.width, W) - self.assertEqual(stream.codec_context.height, H) + self.assertEqual(stream.codec_context.width, W + 1) + self.assertEqual(stream.codec_context.height, H + 1) self.assertEqual(stream.codec_context.framerate, FPS) # Check we can decode all frames frame_count = sum(1 for _ in container.decode(stream)) diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py index 95c4a35..25c2b44 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer.py @@ -148,12 +148,11 @@ def _encode_and_log_video( temp_path.parent.mkdir(parents=True, exist_ok=True) with open(temp_path, "wb") as f: _audio_video.encode_video(video_array, f) # pyright: ignore[reportOptionalMemberAccess] + dest_dir = os.path.join("videos", os.path.dirname(rel_path)).rstrip("/") # If log_artifact(synchronous=False) existed, # we could synchronize with self.flush() rather than at the end of write_videos. # https://github.com/mlflow/mlflow/issues/14153 - self._client.log_artifact( - self._run_id, temp_path, os.path.join("videos", rel_path) - ) + self._client.log_artifact(self._run_id, temp_path, dest_dir) def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int): """MLflow doesn't support audio logging directly.""" diff --git a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py index 3771fcf..16cf0a5 100644 --- a/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py +++ b/src/jax_loop_utils/metric_writers/mlflow/metric_writer_test.py @@ -156,10 +156,12 @@ def test_write_videos(self): self.assertEqual( sorted_artifacts_videos[0].path, "videos/noise_1_000000000.mp4" ) + self.assertFalse(sorted_artifacts_videos[0].is_dir) artifacts_zzz = writer._client.list_artifacts(run.info.run_id, "videos/zzz") self.assertEqual(len(artifacts_zzz), 1) self.assertEqual(artifacts_zzz[0].path, "videos/zzz/noise_0_000000000.mp4") + self.assertFalse(artifacts_zzz[0].is_dir) def test_no_ops(self): with tempfile.TemporaryDirectory() as temp_dir: