From 9aeee9ea0b09ba097d890f925f3d0a42a2e338c8 Mon Sep 17 00:00:00 2001 From: Max Cembalest Date: Thu, 6 Mar 2025 22:15:57 -0500 Subject: [PATCH] excess image resize maintain aspect ratio --- nomic/dataset.py | 6 +++--- nomic/embed.py | 12 +----------- nomic/utils.py | 10 ++++++++++ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/nomic/dataset.py b/nomic/dataset.py index 05003842..40f415c7 100644 --- a/nomic/dataset.py +++ b/nomic/dataset.py @@ -35,7 +35,7 @@ ) from .data_operations import AtlasMapData, AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics from .settings import * -from .utils import assert_valid_project_id, download_feather +from .utils import assert_valid_project_id, download_feather, resize_pil class AtlasUser: @@ -1454,7 +1454,7 @@ def _add_blobs( image = Image.open(blob) image = image.convert("RGB") if image.height > 512 or image.width > 512: - image = image.resize((512, 512)) + image = resize_pil(image) buffered = BytesIO() image.save(buffered, format="JPEG") images.append((uuid, buffered.getvalue())) @@ -1463,7 +1463,7 @@ def _add_blobs( elif isinstance(blob, Image.Image): blob = blob.convert("RGB") # type: ignore if blob.height > 512 or blob.width > 512: - blob = blob.resize((512, 512)) + blob = resize_pil(blob) buffered = BytesIO() blob.save(buffered, format="JPEG") images.append((uuid, buffered.getvalue())) diff --git a/nomic/embed.py b/nomic/embed.py index c60df7db..3a2fa79a 100644 --- a/nomic/embed.py +++ b/nomic/embed.py @@ -16,6 +16,7 @@ from .dataset import AtlasClass from .settings import * +from .utils import resize_pil try: from gpt4all import CancellationError, Embed4All @@ -345,17 +346,6 @@ def image_api_request( raise Exception((response.status_code, response.text)) -def resize_pil(img): - width, height = img.size - # if image is too large, downsample before sending over the wire - max_width = 512 - max_height = 512 - if width > max_width or height > max_height: - downsize_factor = max(width // max_width, height // max_height) - img = img.resize((width // downsize_factor, height // downsize_factor)) - return img - - def _is_valid_url(url): if not isinstance(url, str): return False diff --git a/nomic/utils.py b/nomic/utils.py index 57137760..424b3817 100644 --- a/nomic/utils.py +++ b/nomic/utils.py @@ -288,3 +288,13 @@ def download_feather( if not download_success or schema is None: raise ValueError(f"Failed to download feather file from {url} after {num_attempts} attempts.") return schema + +def resize_pil(img): + width, height = img.size + # if image is too large, downsample before sending over the wire + max_width = 512 + max_height = 512 + if width > max_width or height > max_height: + downsize_factor = max(width // max_width, height // max_height) + img = img.resize((width // downsize_factor, height // downsize_factor)) + return img