Minimal changes for mlx (and jax) compatibility for Image layers #6553
Conversation
Bonus points: this works fine too. 😃

```python
import napari
import mlx.core as mx
import numpy as np
import jax.numpy as jnp

ran = mx.random.uniform(shape=(5, 10, 10))
jan = jnp.array(np.asarray(ran))

viewer, layer = napari.imshow(ran)
layer2 = viewer.add_image(jan)

if __name__ == '__main__':
    napari.run()
```
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #6553      +/-   ##
==========================================
- Coverage   92.29%   92.25%   -0.05%
==========================================
  Files         601      601
  Lines       53687    53709      +22
==========================================
- Hits        49553    49550       -3
- Misses       4134     4159      +25
```

☔ View full report in Codecov by Sentry.
```diff
@@ -122,7 +122,7 @@ def _active_shape(s: LayerSel) -> Optional[Tuple[int, ...]]:


 def _same_shape(s: LayerSel) -> bool:
-    return len({getattr(x.data, "shape", ()) for x in s}) == 1
+    return len({tuple(getattr(x.data, "shape", ())) for x in s}) == 1
```
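The `tuple` cast matters because a list-valued `shape` is unhashable and cannot be placed in a set; a quick illustration in plain Python:

```python
# mlx-style shape (a list) vs. numpy-style shape (a tuple):
list_shape = [5, 10, 10]
tuple_shape = (5, 10, 10)

# Lists are unhashable, so the original set comprehension would fail:
try:
    {list_shape}
except TypeError as e:
    print(e)  # unhashable type: 'list'

# Coercing to tuple makes both forms hashable and equal:
print(len({tuple(list_shape), tuple_shape}))  # 1
```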
Please add a comment explaining why `tuple` is required. What type do mlx or jax use for `shape`?
Good point! I've added a docstring with a corresponding note.
Which libraries return a list?
Only mlx that I know of. I also just saw that this is actually against the array-API spec, which specifies that the .shape attribute must be immutable. However, this is a low-cost modification to add compatibility with current library versions, so I suggest leaving this in. I'll add a comment that we can remove the tuple cast once ~all "major" array libraries return a tuple.
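For comparison, a quick check that numpy already follows the array-API requirement (numpy here is just an example of a conforming library):

```python
import numpy as np

a = np.zeros((5, 10, 10))
# Conforming libraries expose shape as an immutable tuple...
print(isinstance(a.shape, tuple))  # True
# ...so it is hashable and works in set comprehensions without a cast:
print(len({a.shape, (5, 10, 10)}))  # 1
```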
done
Wonder why they did that; maybe worth opening an issue?
Cool! Works nicely!
@psobolewskiPhD I'm really excited to reimplement my napariboard demo with shared memory working on M-series machines, and without needing to wrap pytorch tensors in dask +
Cool! Would be neat if the visualization side (e.g. vispy et al.) could access those arrays...
* main:
  - Use blobs instead of random integers (napari#6527)
  - Set default dtype for empty `_ImageSliceResponse` (napari#6552)
  - Add creating image from clipboard (napari#6532)
  - Fix check if plugin is available from conda (napari#6545)
  - Fix generation of layer creation functions docstrings (napari#6558)
  - Print a whole stack if throttler is not finished (napari#6549)
  - Update `app-model`, `babel`, `coverage`, `dask`, `fsspec`, `hypothesis`, `imageio`, `ipython`, `lxml`, `magicgui`, `pandas`, `pint`, `psutil`, `pydantic`, `pyqt6`, `pytest-qt`, `tensorstore`, `tifffile`, `virtualenv`, `xarray` (napari#6478)
  - Minimal changes for mlx (and jax) compatibility for Image layers (napari#6553)
  - [pre-commit.ci] pre-commit autoupdate (napari#6528)
  - Moving IntensityVisualizationMixin from _ImageBase to Image (napari#6548)
In case anyone hits this: with this PR you can also use torch tensors:
🎉 In 0.4.19 you get an error
To be clear, that only works with plain torch tensors. If you have gradients or they are on the GPU, then you have extra work to do. See pytorch/pytorch#36560. I promised to make a NEP for
These two changes are sufficient to allow `napari.imshow` to work with arrays from mlx, Apple's new array library:
- `range_to_decimals` had a tiny hardcoded hack for the tensorstore array dtype. But we have since implemented the `normalize_dtype` function, which works with a much broader range of arrays. In this PR, we change to use `normalize_dtype`, which works out of the box with mlx arrays.
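As a rough illustration of what such normalization buys (this `normalize_dtype_sketch` is a hypothetical simplification, not napari's actual implementation):

```python
import numpy as np

def normalize_dtype_sketch(dtype_spec):
    # Hypothetical simplification: many array libraries expose dtype
    # objects whose string form numpy understands ('float32', 'uint8', ...),
    # so funneling them through np.dtype yields one canonical scalar type.
    return np.dtype(str(dtype_spec)).type

# Strings and numpy dtype objects normalize to the same scalar type:
print(normalize_dtype_sketch('float32') is np.float32)            # True
print(normalize_dtype_sketch(np.dtype('float32')) is np.float32)  # True
```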
- `_same_shape` uses each layer's `Image.data.shape` to try to make a set. It turns out that mlx arrays return a list, rather than a tuple, when checking the shape, so this PR coerces `.shape` to a tuple to make sure things are working.
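The effect of the coercion can be sketched with minimal stand-ins for layers (the `FakeData`/`FakeLayer` classes here are illustrative, not napari code):

```python
class FakeData:
    def __init__(self, shape):
        self.shape = shape  # may be a list (mlx-style) or a tuple (numpy-style)

class FakeLayer:
    def __init__(self, shape):
        self.data = FakeData(shape)

def same_shape(layers):
    # Mirrors the patched _same_shape: coercing to tuple makes
    # list-valued shapes hashable and equal to tuple-valued ones.
    return len({tuple(getattr(layer.data, "shape", ())) for layer in layers}) == 1

print(same_shape([FakeLayer([5, 10, 10]), FakeLayer((5, 10, 10))]))  # True
print(same_shape([FakeLayer([5, 10, 10]), FakeLayer((2, 2))]))       # False
```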
I don't want to add an mlx test dependency because it's such a young library,
but these changes are very inoffensive so I hope we can just merge them. You
can test that this works locally with:
This works great with these changes! 😊