[Feature request] unsupported int8 · Issue #28305 · jax-ml/jax · GitHub
[Feature request] unsupported int8 #28305
Open
@appujee

Description

Running `test_add_constant` (in `tests/pallas/ops_test.py`) with `int8` fails on TPU. The test contains the comment `# TODO(twsung): Add more types once lowering is implemented.` Is there any guidance on how to add support for a new type?
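A minimal stand-alone version of the failing case might look like the following sketch. It is an assumption, not the code in `ops_test.py`: the kernel name, the `(32, 128)` shape, and the interpret-mode fallback are hypothetical choices. On TPU it is expected (not verified) to hit the same `arith.addi` legalization error, while on other backends it runs in Pallas interpret mode and produces a reference result.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

# Hypothetical repro; the shape is an assumption, not taken from ops_test.py.
SHAPE = (32, 128)

def add_constant_kernel(x_ref, o_ref):
    # Elementwise add of a Python constant. For int8 inputs the result stays
    # int8, which Mosaic lowers to an i8 arith.addi on TPU.
    o_ref[...] = x_ref[...] + 1

def add_constant(x):
    # Compile through Mosaic on TPU; use Pallas interpret mode elsewhere so the
    # sketch also runs on CPU/GPU.
    interpret = jax.default_backend() != "tpu"
    return pl.pallas_call(
        add_constant_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x)

x = jnp.arange(SHAPE[0] * SHAPE[1], dtype=jnp.int32).reshape(SHAPE).astype(jnp.int8)
out = add_constant(x)

# Compare against a NumPy reference computed in int32 and truncated to int8.
expected = (np.asarray(x).astype(np.int32) + 1).astype(np.int8)
print(np.array_equal(np.asarray(out), expected))
```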

```
[  FAILED  ] OpsTest.test_add_constant3 (<class 'jax.numpy.int8'>)
======================================================================
ERROR: test_add_constant3 (<class 'jax.numpy.int8'>) (__main__.OpsTest)
OpsTest.test_add_constant3 (<class 'jax.numpy.int8'>)
test_add_constant(<class 'jax.numpy.int8'>)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/build/work/fbbed6b09c4a36081fe11060e1266d521c13/google3/runfiles/google3/third_party/py/jax/_src/compiler.py", line 324, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.jaxlib._jax.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: failed to legalize operation 'arith.addi'

at location: loc("/add"(callsite("OpsTest.test_add_constant.<locals>.kernel"("third_party/py/jax/tests/pallas/ops_test.py":983:19 to :33) at callsite("OpsTest.test_add_constant"("third_party/py/jax/tests/pallas/ops_test.py":986:8 to :45) at callsite("_ParameterizedTestIter.__iter__.<locals>.make_bound_param_test.<locals>.bound_param_test"("third_party/py/absl/testing/parameterized.py":325:17 to :51) at callsite("_run_and_get_tests_result"("third_party/py/absl/testing/absltest.py":2904:19 to :56) at callsite("run_tests"("third_party/py/absl/testing/absltest.py":2940:35 to 2942:3) at callsite("_run_in_app.<locals>.main_function"("third_party/py/absl/testing/absltest.py":2450:6 to :34) at callsite("_run_main"("third_party/py/absl/app.py":404:13 to :23) at callsite("run"("third_party/py/absl/app.py":484:6 to :27) at callsite("_run_in_app"("third_party/py/absl/testing/absltest.py":2452:4 to :31) at "main"("third_party/py/absl/testing/absltest.py":2334:2 to :38))))))))))))

The MLIR operation involved:
  %448 = "arith.addi"(%84, %67) <{overflowFlags = #arith.overflow<none>}> : (vector<8x128x4xi8>, vector<8x128x4xi8>) -> vector<8x128x4xi8>
```
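For reference, the failing operation is an elementwise `i8` `arith.addi` with no overflow flags, which in MLIR means two's-complement wraparound (for example, 127 + 1 gives -128). The sketch below shows the result a lowering for this op would need to match. `add_constant_reference` is a hypothetical helper, not part of JAX or the test suite; it computes in a wider type and wraps back into the int8 range.

```python
import jax.numpy as jnp
import numpy as np

def add_constant_reference(x, c=1):
    """Hypothetical reference for an int8 add-constant.

    Mirrors arith.addi on i8 without nsw/nuw flags: the sum wraps modulo 2**8.
    """
    wide = np.asarray(x).astype(np.int32) + c
    # Map the wide result back into [-128, 127] with two's-complement wrapping.
    return ((wide + 128) % 256 - 128).astype(np.int8)

x = jnp.array([0, 1, 126, 127, -128, -1], dtype=jnp.int8)
ref = add_constant_reference(x)  # 127 + 1 wraps to -128
# jnp's own int8 add (XLA) also wraps, so the two should agree.
print(np.array_equal(ref, np.asarray(x + 1)))
```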

System info (python version, jaxlib version, accelerator, etc.)

jax 0.6.0 (accelerator: TPU, per the Mosaic TPU compile error above)

Assignees: none
Labels: bug (Something isn't working)
