video: fix bugs by garymm · Pull Request #20 · Astera-org/jax_loop_utils · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

video: fix bugs #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions src/jax_loop_utils/metric_writers/_audio_video/audio_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,44 @@ def encode_video(video_array: Array, destination: io.IOBase):
Args:
video_array: array to encode. Must have shape (T, H, W, 1) or (T, H, W, 3),
where T is the number of frames, H is the height, W is the width, and the last
dimension is the number of channels. Must have dtype uint8.
dimension is the number of channels.
Must be ints in [0, 255] or floats in [0, 1].
destination: Destination to write the encoded video.
"""
video_array = np.array(video_array)
if video_array.ndim != 4 or video_array.shape[-1] not in (1, 3):
raise ValueError(
"Expected an array with shape (T, H, W, 1) or (T, H, W, 3)."
f"Got shape {video_array.shape} with dtype {video_array.dtype}."
)

if (
video_array.dtype != np.uint8
or video_array.ndim != 4
or video_array.shape[-1] not in (1, 3)
np.issubdtype(video_array.dtype, np.floating)
and np.all(0 <= video_array)
and np.all(video_array <= 1.0)
):
video_array = (video_array * 255).astype(np.uint8)
elif (
np.issubdtype(video_array.dtype, np.integer)
and np.all(0 <= video_array)
and np.all(video_array <= 255)
):
video_array = video_array.astype(np.uint8)
else:
raise ValueError(
"Expected a uint8 array with shape (T, H, W, 1) or (T, H, W, 3)."
f"Got shape {video_array.shape} with dtype {video_array.dtype}."
f"Expected video_array to be floats in [0, 1] or ints in [0, 255], got {video_array.dtype}"
)

T, H, W, C = video_array.shape
# Pad height and width to even numbers if necessary
pad_h = H % 2
pad_w = W % 2
if pad_h or pad_w:
padding = [(0, 0), (0, pad_h), (0, pad_w), (0, 0)]
video_array = np.pad(video_array, padding, mode="constant")
H += pad_h
W += pad_w

is_grayscale = C == 1
if is_grayscale:
video_array = np.squeeze(video_array, axis=-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@ class VideoTest(absltest.TestCase):
def test_encode_video_invalid_args(self):
"""Test that encode_video raises appropriate errors for invalid inputs."""
invalid_shape = np.zeros((10, 20, 30, 4), dtype=np.uint8)
with self.assertRaisesRegex(ValueError, "Expected a uint8 array with shape"):
with self.assertRaisesRegex(ValueError, r"Expected an array with shape"):
encode_video(invalid_shape, io.BytesIO())

invalid_dtype = np.zeros((10, 20, 30, 3), dtype=np.float32)
with self.assertRaisesRegex(ValueError, "Expected a uint8 array with shape"):
invalid_dtype = 2 * np.ones((10, 20, 30, 3), dtype=np.float32)
with self.assertRaisesRegex(
ValueError, r"Expected video_array to be floats in \[0, 1\]"
):
encode_video(invalid_dtype, io.BytesIO())

def test_encode_video_success(self):
"""Test successful video encoding."""
# Create a simple test video - red square moving diagonally
T, H, W = 20, 64, 64
T, H, W = 20, 63, 63 # test non-even dimensions
video = np.zeros((T, H, W, 3), dtype=np.uint8)
for t in range(T):
pos = t * 5 # Move 5 pixels each frame
Expand All @@ -41,8 +43,8 @@ def test_encode_video_success(self):
output.seek(0)
with av.open(output, mode="r", format=CONTAINER_FORMAT) as container:
stream = container.streams.video[0]
self.assertEqual(stream.codec_context.width, W)
self.assertEqual(stream.codec_context.height, H)
self.assertEqual(stream.codec_context.width, W + 1)
self.assertEqual(stream.codec_context.height, H + 1)
self.assertEqual(stream.codec_context.framerate, FPS)
# Check we can decode all frames
frame_count = sum(1 for _ in container.decode(stream))
Expand Down
5 changes: 2 additions & 3 deletions src/jax_loop_utils/metric_writers/mlflow/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,11 @@ def _encode_and_log_video(
temp_path.parent.mkdir(parents=True, exist_ok=True)
with open(temp_path, "wb") as f:
_audio_video.encode_video(video_array, f) # pyright: ignore[reportOptionalMemberAccess]
dest_dir = os.path.join("videos", os.path.dirname(rel_path)).rstrip("/")
# If log_artifact(synchronous=False) existed,
# we could synchronize with self.flush() rather than at the end of write_videos.
# https://github.com/mlflow/mlflow/issues/14153
self._client.log_artifact(
self._run_id, temp_path, os.path.join("videos", rel_path)
)
self._client.log_artifact(self._run_id, temp_path, dest_dir)

def write_audios(self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
"""MLflow doesn't support audio logging directly."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,12 @@ def test_write_videos(self):
self.assertEqual(
sorted_artifacts_videos[0].path, "videos/noise_1_000000000.mp4"
)
self.assertFalse(sorted_artifacts_videos[0].is_dir)

artifacts_zzz = writer._client.list_artifacts(run.info.run_id, "videos/zzz")
self.assertEqual(len(artifacts_zzz), 1)
self.assertEqual(artifacts_zzz[0].path, "videos/zzz/noise_0_000000000.mp4")
self.assertFalse(artifacts_zzz[0].is_dir)

def test_no_ops(self):
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down
0