Describe the bug
Calling an activation function on float64 arrays on the CPU stream fails with an `Unsupported constant type` error.
To Reproduce
The script below tests both a scalar and a non-scalar MLX array. Tested on mlx 0.25.2 and 0.26.1:
```python
import mlx.core as mx
import mlx.nn as nn

dtype = mx.float64

# float64 is used on the CPU stream
with mx.stream(mx.cpu):
    # (1) scalar array
    x = mx.ones((), dtype=dtype)
    try:
        x = nn.celu(x)
    except Exception as e:
        print(f'(1) failed activation call with exception: {e}')
    try:
        mx.eval(x)
    except Exception as e:
        print(f'(1) failed eval with exception: {e}')

    # (2) non-scalar array with 10 elements
    x = mx.ones((10,), dtype=dtype)
    try:
        x = nn.celu(x)
    except Exception as e:
        print(f'(2) failed activation call with exception: {e}')
    try:
        mx.eval(x)
    except Exception as e:
        print(f'(2) failed eval with exception: {e}')

# mlx==0.25.2
# (1) failed eval with exception: Unsupported constant type
# (2) failed eval with exception: Unsupported constant type
# mlx==0.26.1
# (1) failed activation call with exception: Unsupported constant type
# (2) failed activation call with exception: Unsupported constant type
```
Expected behavior
I'd expect these calls to succeed on the CPU stream.
Desktop (please complete the following information):
- OS Version: macOS 15.2
- Version: 0.26.1 and 0.25.2
Additional context
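Note that the failure surfaces at a different point depending on the version: in mlx 0.25.2 it is raised by `mx.eval`, while in 0.26.1 it is raised by the activation call itself. The same error also occurs for many of the activation functions, not just `celu` as in the example.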
(This occurred in tests for keras-team/keras#19571)