Description
Running test_add_constant with int8 fails. The test carries the comment "# TODO(twsung): Add more types once lowering is implemented." Is there any guidance on how to add support for a new type?
[ FAILED ] OpsTest.test_add_constant3 (<class 'jax.numpy.int8'>)
======================================================================
ERROR: test_add_constant3 (<class 'jax.numpy.int8'>) (__main__.OpsTest)
OpsTest.test_add_constant3 (<class 'jax.numpy.int8'>)
test_add_constant(<class 'jax.numpy.int8'>)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/build/work/fbbed6b09c4a36081fe11060e1266d521c13/google3/runfiles/google3/third_party/py/jax/_src/compiler.py", line 324, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.jaxlib._jax.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: failed to legalize operation 'arith.addi'
at location: loc("/add"(callsite("OpsTest.test_add_constant.<locals>.kernel"("third_party/py/jax/tests/pallas/ops_test.py":983:19 to :33) at callsite("OpsTest.test_add_constant"("third_party/py/jax/tests/pallas/ops_test.py":986:8 to :45) at callsite("_ParameterizedTestIter.__iter__.<locals>.make_bound_param_test.<locals>.bound_param_test"("third_party/py/absl/testing/parameterized.py":325:17 to :51) at callsite("_run_and_get_tests_result"("third_party/py/absl/testing/absltest.py":2904:19 to :56) at callsite("run_tests"("third_party/py/absl/testing/absltest.py":2940:35 to 2942:3) at callsite("_run_in_app.<locals>.main_function"("third_party/py/absl/testing/absltest.py":2450:6 to :34) at callsite("_run_main"("third_party/py/absl/app.py":404:13 to :23) at callsite("run"("third_party/py/absl/app.py":484:6 to :27) at callsite("_run_in_app"("third_party/py/absl/testing/absltest.py":2452:4 to :31) at "main"("third_party/py/absl/testing/absltest.py":2334:2 to :38))))))))))))
The MLIR operation involved:
%448 = "arith.addi"(%84, %67) <{overflowFlags = #arith.overflow<none>}> : (vector<8x128x4xi8>, vector<8x128x4xi8>) -> vector<8x128x4xi8>
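A possible interim workaround is to widen the int8 values to int32 inside the kernel, so that the add is emitted as an `arith.addi` on i32 instead of i8, and then narrow the result back to int8. This is an assumption and has not been verified on TPU with jax 0.6.0. Whether Mosaic can lower the int8/int32 casts themselves may depend on the jaxlib version. The kernel and the `add_constant` wrapper below are hypothetical illustrations, not the code in ops_test.py. Passing `interpret=True` runs the kernel through the Pallas interpreter, so the sketch can be checked on CPU without Mosaic.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_constant_kernel(x_ref, o_ref):
    # Hypothetical workaround: upcast int8 -> int32 so the add lowers on i32,
    # then cast the result back down to int8 before storing it.
    x = x_ref[...].astype(jnp.int32)
    o_ref[...] = (x + 1).astype(jnp.int8)


def add_constant(x, interpret=False):
    # No grid or BlockSpecs: the whole array is passed to the kernel as one block.
    return pl.pallas_call(
        add_constant_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int8),
        interpret=interpret,
    )(x)
```

Usage: `add_constant(jnp.full((32, 128), 7, dtype=jnp.int8), interpret=True)` should return an int8 array filled with 8. On a TPU, leaving `interpret=False` exercises the Mosaic lowering path.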
System info (python version, jaxlib version, accelerator, etc.)
jax 0.6.0