
Minimal changes for mlx (and jax) compatibility for Image layers #6553


Merged: @jni merged 4 commits into napari:main from jni:dtypes-from-anyone on Dec 28, 2023

Conversation

@jni (Member) commented on Dec 24, 2023

These two changes are sufficient to allow napari.imshow to work with arrays
from mlx, Apple's new array library:

  1. The function range_to_decimals had a tiny hardcoded hack for tensorstore
    array dtypes. We have since implemented normalize_dtype, which handles a
    much broader range of arrays. This PR switches to normalize_dtype, which
    works out of the box with mlx arrays (see the sketch after this list).
  2. In one place we collect Image.data.shape values into a set. It turns out
    that mlx arrays return a list, rather than a tuple, for .shape, and lists
    are unhashable, so this PR coerces .shape to a tuple to keep the set
    construction working.
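
A minimal sketch of how the first change plays out. The import path
napari.utils._dtype is my best guess at the helper's current home and may
differ:

import numpy as np
from napari.utils._dtype import normalize_dtype  # path is an assumption

# normalize_dtype maps dtype objects from many array libraries
# (numpy, tensorstore, torch, and now mlx) onto a plain numpy dtype,
# replacing the old tensorstore-only special case in range_to_decimals
dt = normalize_dtype(np.zeros(3, dtype=np.float32).dtype)
print(dt)  # -> np.float32 (a plain numpy scalar type)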

I don't want to add an mlx test dependency because it's such a young library,
but these changes are very inoffensive so I hope we can just merge them. You
can test that this works locally with:

import napari
import mlx.core as mx

# a small random 3D array (5 planes of 10x10)
ran = mx.random.uniform(shape=(5, 10, 10))

viewer, layer = napari.imshow(ran)

if __name__ == '__main__':
    napari.run()

This works great with these changes! 😊

@github-actions bot added the qt (Relates to qt) label on Dec 24, 2023
@jni (Member Author) commented on Dec 24, 2023

Bonus points: this works fine too. 😃

import napari
import mlx.core as mx
import numpy as np

import jax.numpy as jnp

ran = mx.random.uniform(shape=(5, 10, 10))
jan = jnp.array(np.asarray(ran))  # round-trip mlx -> numpy -> jax

viewer, layer = napari.imshow(ran)
layer2 = viewer.add_image(jan)

if __name__ == '__main__':
    napari.run()

@jni jni changed the title from "Minimal changes for mlx compatibility" to "Minimal changes for mlx (and jax) compatibility for Image layers" on Dec 24, 2023
@codecov codecov bot commented on Dec 24, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (d507dd6) 92.29% compared to head (1c055c5) 92.25%.
Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #6553      +/-   ##
==========================================
- Coverage   92.29%   92.25%   -0.05%     
==========================================
  Files         601      601              
  Lines       53687    53709      +22     
==========================================
- Hits        49553    49550       -3     
- Misses       4134     4159      +25     


@@ -122,7 +122,7 @@ def _active_shape(s: LayerSel) -> Optional[Tuple[int, ...]]:


 def _same_shape(s: LayerSel) -> bool:
-    return len({getattr(x.data, "shape", ()) for x in s}) == 1
+    return len({tuple(getattr(x.data, "shape", ())) for x in s}) == 1
Collaborator:
Please add a comment explaining why the tuple is required. What type do mlx or jax use for shape?

Member Author (@jni):
Good point! I've added a docstring with a corresponding note.
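
For reference, a sketch of what the annotated helper could look like; the
docstring wording here is illustrative, not the exact text that was merged:

def _same_shape(s: LayerSel) -> bool:
    """Return True if all layers in the selection have the same shape.

    The tuple() cast is needed because some array libraries (currently
    mlx) return a mutable, unhashable list from .shape, which cannot be
    collected into a set.
    """
    # LayerSel comes from the surrounding module, as in the diff above
    return len({tuple(getattr(x.data, "shape", ())) for x in s}) == 1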

Collaborator:

Which libraries return a list?

Member Author (@jni):

Only mlx that I know of. I also just saw that this is actually against the array-API spec, which specifies that the .shape attribute must be immutable. However, this is a low-cost modification to add compatibility with current library versions, so I suggest leaving this in. I'll add a comment that we can remove the tuple cast once ~all "major" array libraries return a tuple.
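
A quick demonstration of the hashability point (the mlx lines are commented
out so the snippet runs without mlx installed):

import numpy as np
# import mlx.core as mx

# numpy (and most libraries) return an immutable, hashable tuple:
assert isinstance(np.zeros((5, 10)).shape, tuple)

# mlx, at the time of writing, returns a mutable list instead:
# mx.zeros((5, 10)).shape  ->  [5, 10]

# lists cannot go in a set, so {x.data.shape for x in s} would raise
# TypeError for mlx data; tuple() makes both cases hashable:
assert {tuple(s) for s in ([5, 10], (5, 10))} == {(5, 10)}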

Member Author (@jni):

done

Member:

Wonder why they did that; maybe worth opening an issue?

@Czaki Czaki added the maintenance (PR with maintenance changes) label on Dec 25, 2023
@Czaki Czaki added this to the 0.5.0 milestone on Dec 25, 2023
@Czaki Czaki added the ready to merge (Last chance for comments! Will be merged in ~24h) label on Dec 26, 2023
@psobolewskiPhD (Member) commented:
Cool! Works nicely!
For a few simple things I did import mlx.core as np and it also worked
(notably not random.random, though); see the sketch below.
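
A hedged sketch of that drop-in trick, with my own example calls (not the
ones psobolewskiPhD ran):

import mlx.core as np  # treating mlx.core as a partial numpy drop-in

a = np.arange(12).reshape(3, 4)  # arange/reshape mirror numpy
b = np.ones((3, 4)) * 2
print(a + b)

# coverage is partial: numpy's np.random.random has no mlx counterpart;
# mlx spells it random.uniform(shape=...) instead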

@jni jni merged commit 8b10df9 into napari:main Dec 28, 2023
@jni jni deleted the dtypes-from-anyone branch December 28, 2023 00:27
@jni (Member Author) commented on Dec 28, 2023

@psobolewskiPhD I'm really excited to reimplement my napariboard demo with shared memory working on M-series machines, and without needing to wrap pytorch tensors in dask + .cpu().detach() calls. 😃 (I of course would prefer an open cross-platform standard, but we don't live in that timeline. 😭)

@Czaki Czaki removed the ready to merge (Last chance for comments! Will be merged in ~24h) label on Dec 28, 2023
@psobolewskiPhD (Member) commented:
Cool! It would be neat if the visualization side (e.g. vispy et al.) could access those arrays...

kne42 added a commit to kne42/napari that referenced this pull request Jan 4, 2024
* main:
  Use blobs instead of random integers (napari#6527)
  Set default dtype for empty `_ImageSliceResponse` (napari#6552)
  Add creating image from clipboard (napari#6532)
  Fix check if plugin is available from conda (napari#6545)
  Fix generation of layer creation functions docstrings (napari#6558)
  Print a whole stack if throttler is not finished (napari#6549)
  Update `app-model`, `babel`, `coverage`, `dask`, `fsspec`, `hypothesis`, `imageio`, `ipython`, `lxml`, `magicgui`, `pandas`, `pint`, `psutil`, `pydantic`, `pyqt6`, `pytest-qt`, `tensorstore`, `tifffile`, `virtualenv`, `xarray` (napari#6478)
  Minimal changes for mlx (and jax) compatibility for Image layers (napari#6553)
  [pre-commit.ci] pre-commit autoupdate (napari#6528)
  Moving IntensityVisualizationMixin from _ImageBase to Image (napari#6548)
@psobolewskiPhD (Member) commented on Apr 23, 2024

In case anyone hits this: with this PR you can also use torch tensors:

In [1]: import torch
In [2]: import napari

In [3]: tens = torch.tensor([[1, 2], [3, 4]])

In [4]: viewer = napari.Viewer()

In [5]: viewer.add_image(tens)
Out[5]: <Image layer 'tens' at 0x105206910>

🎉

In 0.4.19 you get an error: TypeError: Cannot interpret 'torch.int64' as a data type.

@jni (Member Author) commented on Apr 28, 2024

To be clear, that only works with plain torch tensors. If your tensors carry gradients or live on the GPU, there's extra work to do (see the sketch below, and pytorch/pytorch#36560). I promised to write a NEP for force=True in numpy.asarray/__array__ methods, but didn't get around to it. 😅 Still, the second-best time is now, if someone wants to join me on that endeavour.
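
A minimal sketch of that extra work, assuming a tensor with gradients; the
same .detach().cpu() dance applies to GPU tensors:

import napari
import torch

# np.asarray raises on a tensor that requires grad, so detach from the
# autograd graph (and move to host memory if on the GPU) first:
t = torch.rand(5, 10, 10, requires_grad=True)
viewer, layer = napari.imshow(t.detach().cpu())

if __name__ == '__main__':
    napari.run()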

Labels
maintenance (PR with maintenance changes) · qt (Relates to qt)