Results do not match the reference. This is likely a bug/unexpected loss of precision #24909
Comments
I also ran into this error and it seems to be related to the sample size used during training.
I run into the same error message executing a custom model. However, it only occurs if:
Does anybody else have the same problem? I narrowed the problem down to a single matmul in a submodule component; however, when I run the sub-component separately there is no problem, so I cannot really isolate it. SYSTEM INFO: CONDA ENV:
Interestingly, the same jax version as the OP. CODE:
OUTPUT:
I believe it might be related to this issue in jax versions 0.4.33-0.4.35: #24843
The problem also occurs in newer versions, specifically jax & jaxlib version 0.5.1.
Here is a minimal reproduction with vanilla jax:
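The reproduction code itself is not preserved in this thread. A hypothetical sketch of such a vanilla-jax comparison (array names, sizes, and the error metric are my assumptions, not the OP's exact code): cast inputs to bfloat16, run the matmul on the device, and compare against a float32 reference computed on the host.

```python
# Sketch only: compare a device bfloat16 matmul against a host float32 reference.
import jax
import jax.numpy as jnp
import numpy as np

def relative_error(size=512, seed=0):
    key = jax.random.PRNGKey(seed)
    kx, ky = jax.random.split(key)
    x = jax.random.normal(kx, (size, size), dtype=jnp.float32)
    y = jax.random.normal(ky, (size, size), dtype=jnp.float32)
    # Float32 reference computed with NumPy on the host.
    ref = np.asarray(x) @ np.asarray(y)
    # bfloat16 matmul on the device, where the reported mismatch shows up.
    out = jnp.matmul(x.astype(jnp.bfloat16), y.astype(jnp.bfloat16))
    out = np.asarray(out, dtype=np.float32)
    # Max absolute deviation, normalized by the largest reference magnitude.
    return float(np.abs(out - ref).max() / np.abs(ref).max())

print(f"max relative error: {relative_error():.4f}")
```

On a backend where the matmul genuinely runs in bf16 the error should be small but nonzero; a large discrepancy here would match the behavior reported in this issue.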
I tried the above code on Google Colab with a "v2-8 TPU" and a "T4 GPU", and the problem doesn't show up on jax 0.5.1 or 0.4.33. Furthermore, I printed the HLO stack like so:
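The snippet used for this is not shown in the thread; a minimal sketch of one way to dump the lowered and compiled representations of a jitted function, so the dtypes of the `dot` ops can be inspected (the function `f` and its bf16 inputs are assumptions for illustration):

```python
# Sketch: inspect which dtypes the compiled matmul actually uses.
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.matmul(x, y)

x = jnp.ones((8, 8), dtype=jnp.bfloat16)
y = jnp.ones((8, 8), dtype=jnp.bfloat16)

lowered = jax.jit(f).lower(x, y)
# StableHLO as emitted by JAX, before backend optimization:
print(lowered.as_text())
# Backend-optimized HLO; search for f32 vs bf16 "dot" ops here:
print(lowered.compile().as_text())
```

Grepping the compiled text for `dot` shows whether the backend kept the operation in bf16 or upcast it to f32.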
I noticed that only on the TPU is the code purely in "bf16"; on the A100 (and also the T4), even if the model weights and precision are set to bf16, it still performs many "fp32" operations. Is this expected behavior on the A100 as well?
I got these messages on the A100 card, and
See jax-ml/jax#24909. Without setting this flag, the precision will become very low. See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
I don't get these messages even if I set
See jax-ml/jax#24909. Without setting this flag, the precision will become very low, which seems to be a bug (the documentation says the GPU uses tensorfloat32 or float32, but the default behavior seems wrong...). See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
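Per the `jax.default_matmul_precision` documentation linked above, the precision can be raised either for a specific block of code or globally. A minimal sketch of both forms (the tiny all-ones matrix is just for illustration):

```python
# Sketch: two ways to control JAX's default matmul precision.
import jax
import jax.numpy as jnp

x = jnp.ones((4, 4), dtype=jnp.float32)

# Context-manager form: affects only matmuls inside the block.
with jax.default_matmul_precision("float32"):
    z = x @ x

# Global form, e.g. at program start; the JAX_DEFAULT_MATMUL_PRECISION
# environment variable sets the same option from outside the program.
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
```

The context-manager form is the less invasive fix when only a few matmuls are precision-sensitive; the global form matches what the commit message above applies via the environment.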
Description
I get these errors after running the code above:
There are too many errors to display them all, but the other errors were similar to these:
However, when I made the sizes of x, y, and z smaller, like:
the errors disappeared.
System info (python version, jaxlib version, accelerator, etc.)