Running test_add_constant with int8 fails. The test contains the comment
# TODO(twsung): Add more types once lowering is implemented.
Is there any guidance on how to add support for a new type?
[ FAILED ] OpsTest.test_add_constant3 (<class 'jax.numpy.int8'>)
======================================================================
ERROR: test_add_constant3 (<class 'jax.numpy.int8'>) (__main__.OpsTest)
OpsTest.test_add_constant3 (<class 'jax.numpy.int8'>)
test_add_constant(<class 'jax.numpy.int8'>)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/build/work/fbbed6b09c4a36081fe11060e1266d521c13/google3/runfiles/google3/third_party/py/jax/_src/compiler.py", line 324, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.jaxlib._jax.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: failed to legalize operation 'arith.addi'
at location: loc("/add"(callsite("OpsTest.test_add_constant.<locals>.kernel"("third_party/py/jax/tests/pallas/ops_test.py":983:19 to :33) at callsite("OpsTest.test_add_constant"("third_party/py/jax/tests/pallas/ops_test.py":986:8 to :45) at callsite("_ParameterizedTestIter.__iter__.<locals>.make_bound_param_test.<locals>.bound_param_test"("third_party/py/absl/testing/parameterized.py":325:17 to :51) at callsite("_run_and_get_tests_result"("third_party/py/absl/testing/absltest.py":2904:19 to :56) at callsite("run_tests"("third_party/py/absl/testing/absltest.py":2940:35 to 2942:3) at callsite("_run_in_app.<locals>.main_function"("third_party/py/absl/testing/absltest.py":2450:6 to :34) at callsite("_run_main"("third_party/py/absl/app.py":404:13 to :23) at callsite("run"("third_party/py/absl/app.py":484:6 to :27) at callsite("_run_in_app"("third_party/py/absl/testing/absltest.py":2452:4 to :31) at "main"("third_party/py/absl/testing/absltest.py":2334:2 to :38))))))))))))
The MLIR operation involved:
%448 = "arith.addi"(%84, %67) <{overflowFlags = #arith.overflow<none>}> : (vector<8x128x4xi8>, vector<8x128x4xi8>) -> vector<8x128x4xi8>
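Below is a hypothetical standalone sketch of the failing pattern. The kernel bodies are assumptions written for illustration and are not copied from ops_test.py. add_constant_kernel adds a constant to an int8 block, which on TPU should lower to the same i8 arith.addi that Mosaic rejects above. add_constant_widened_kernel is an unverified workaround to try: it does the add in int32 and casts back to int8. Whether Mosaic can lower the int8 to int32 conversion on a given TPU generation has not been checked. Passing interpret=True runs both kernels through the Pallas interpreter, which lets the arithmetic be checked on CPU without Mosaic.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_constant_kernel(x_ref, o_ref):
    # Direct int8 add. On TPU this lowers to arith.addi on an i8 vector,
    # the op Mosaic failed to legalize in the traceback above.
    o_ref[...] = x_ref[...] + 1


def add_constant_widened_kernel(x_ref, o_ref):
    # Hypothetical workaround (unverified on TPU): widen to int32,
    # add there, then narrow back to the output dtype (int8).
    x = x_ref[...].astype(jnp.int32)
    o_ref[...] = (x + 1).astype(o_ref.dtype)


def run(kernel, x, interpret):
    # Single block covering the whole array (no grid).
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x)


# Values 0..99 so that adding 1 stays inside the int8 range.
x = (jnp.arange(8 * 128, dtype=jnp.int32) % 100).astype(jnp.int8).reshape(8, 128)

# interpret=True checks the logic on CPU. Setting interpret=False on a TPU
# should reproduce the Mosaic error for add_constant_kernel.
y_direct = run(add_constant_kernel, x, interpret=True)
y_widened = run(add_constant_widened_kernel, x, interpret=True)
```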
System info (python version, jaxlib version, accelerator, etc.)
jax 0.6.0