[Feature request] unsupported int8 · Issue #28305 · jax-ml/jax · GitHub

[Feature request] unsupported int8 #28305


Open
appujee opened this issue Apr 25, 2025 · 0 comments
Labels
bug Something isn't working

appujee commented Apr 25, 2025

Description

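Running test_add_constant with int8 fails: Mosaic cannot compile the TPU kernel because it fails to legalize the int8 addition (arith.addi on i8 vectors). The test carries the comment # TODO(twsung): Add more types once lowering is implemented. Is there any guidance on how to add support for a new type?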

[  FAILED  ] OpsTest.test_add_constant3 (<class 'jax.numpy.int8'>)
======================================================================
ERROR: test_add_constant3 (<class 'jax.numpy.int8'>) (__main__.OpsTest)
OpsTest.test_add_constant3 (<class 'jax.numpy.int8'>)
test_add_constant(<class 'jax.numpy.int8'>)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/build/work/fbbed6b09c4a36081fe11060e1266d521c13/google3/runfiles/google3/third_party/py/jax/_src/compiler.py", line 324, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.jaxlib._jax.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: failed to legalize operation 'arith.addi'

at location: loc("/add"(callsite("OpsTest.test_add_constant.<locals>.kernel"("third_party/py/jax/tests/pallas/ops_test.py":983:19 to :33) at callsite("OpsTest.test_add_constant"("third_party/py/jax/tests/pallas/ops_test.py":986:8 to :45) at callsite("_ParameterizedTestIter.__iter__.<locals>.make_bound_param_test.<locals>.bound_param_test"("third_party/py/absl/testing/parameterized.py":325:17 to :51) at callsite("_run_and_get_tests_result"("third_party/py/absl/testing/absltest.py":2904:19 to :56) at callsite("run_tests"("third_party/py/absl/testing/absltest.py":2940:35 to 2942:3) at callsite("_run_in_app.<locals>.main_function"("third_party/py/absl/testing/absltest.py":2450:6 to :34) at callsite("_run_main"("third_party/py/absl/app.py":404:13 to :23) at callsite("run"("third_party/py/absl/app.py":484:6 to :27) at callsite("_run_in_app"("third_party/py/absl/testing/absltest.py":2452:4 to :31) at "main"("third_party/py/absl/testing/absltest.py":2334:2 to :38))))))))))))

The MLIR operation involved:
  %448 = "arith.addi"(%84, %67) <{overflowFlags = #arith.overflow<none>}> : (vector<8x128x4xi8>, vector<8x128x4xi8>) -> vector<8x128x4xi8>
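A minimal sketch of a possible workaround while int8 addition lacks a Mosaic lowering: widen the block to int32 inside the kernel, add, and narrow the result back to int8. This is an assumption, not a documented fix: it presumes that int32 addition and the int8/int32 conversions do lower on TPU for the block shapes involved, and the names add_one, add_one_direct_kernel and add_one_widened_kernel are hypothetical (they are not the kernel from ops_test.py). Passing interpret=True runs the kernel through the Pallas interpreter instead of Mosaic, so the sketch can be checked on CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def add_one_direct_kernel(x_ref, o_ref):
    # Hypothetical kernel that adds directly in int8; when compiled for TPU this
    # presumably emits an i8 arith.addi, the op the issue reports Mosaic
    # cannot legalize.
    o_ref[...] = x_ref[...] + 1


def add_one_widened_kernel(x_ref, o_ref):
    # Hypothetical workaround: do the addition in int32, then cast back to int8.
    x = x_ref[...].astype(jnp.int32)
    o_ref[...] = (x + 1).astype(jnp.int8)


def add_one(x, *, widen=True, interpret=False):
    # No grid is given, so the whole array is processed as a single block.
    kernel = add_one_widened_kernel if widen else add_one_direct_kernel
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x)


if __name__ == "__main__":
    # Values in [-100, 99], so adding 1 cannot overflow int8.
    x = (np.arange(8 * 128) % 200 - 100).astype(np.int8).reshape(8, 128)
    y = add_one(jnp.asarray(x), interpret=True)
    print(y.dtype, y[0, :4])
```

Widening trades extra conversions and register pressure for an op that is more likely to have a lowering; whether that trade is acceptable on TPU would need to be measured.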

System info (python version, jaxlib version, accelerator, etc.)

jax 0.6.0

@appujee appujee added the bug Something isn't working label Apr 25, 2025