Results do not match the reference. This is likely a bug/unexpected loss of precision #24909


Open
yanboyang97 opened this issue Nov 15, 2024 · 8 comments
Labels
bug Something isn't working

Comments

@yanboyang97

Description

import jax.numpy as jnp
import jax
import flax.linen as nn
from collections.abc import Callable
import time

class rbfnet(nn.Module):

    pointnums: int = 100
    areafun: Callable = nn.silu
    init_value: jax.Array = jnp.ones((100, 3))

    def setup(self):
        self.position = self.param('position', lambda rng, init_value: init_value, self.init_value) # (pointnums, dim)
        self.weight = self.param('weight', nn.initializers.zeros, (self.pointnums,)) # (pointnums)
        
    def __call__(self, x: jax.Array):
        batch = x.shape[0]
        x = jnp.expand_dims(x, axis=1).repeat(self.pointnums, axis=1) # (batch, pointnums, dim)
        # print(x.shape)
        position = jnp.expand_dims(self.position, axis=0).repeat(batch, axis=0) # (batch, pointnums, dim)
        position = self.areafun(position)
        # print(position.shape)
        distance = jnp.linalg.norm(x - position, ord=2, axis=-1) # (batch, pointnums)
        output = (1 / distance) @ self.weight
        return output

def main():
    x = jnp.linspace(2, 3, 50)
    y = jnp.linspace(2, 3, 50)
    z = jnp.linspace(2, 3, 50)
    X, Y, Z = jnp.meshgrid(x, y, z, indexing='ij')
    data = jnp.stack([X.reshape(-1), Y.reshape(-1), Z.reshape(-1)], axis=-1)
    # print(data.shape)

    model = rbfnet()
    variables = model.init(jax.random.key(0), data)
    # print(variables)

    @jax.jit
    def forward_and_backward(variables, x):
        # Compute the forward pass
        def loss_fn(variables, x):
            return jnp.mean(model.apply(variables, x))
        loss = loss_fn(variables, x)
        # Compute gradients
        grads = jax.grad(loss_fn)(variables, x)
        return loss, grads

    loss, grads = forward_and_backward(variables, data)

if __name__ == '__main__':
    main()

I get these errors after running the code above:

2024-11-15 10:47:13.733095: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E1115 10:47:13.735496    3712 buffer_comparator.cc:157] Difference at 0: inf, expected 2.00886
E1115 10:47:13.735510    3712 buffer_comparator.cc:157] Difference at 1: inf, expected 2.01074
E1115 10:47:13.735515    3712 buffer_comparator.cc:157] Difference at 2: inf, expected 2.01014
E1115 10:47:13.735519    3712 buffer_comparator.cc:157] Difference at 3: inf, expected 2.00902
E1115 10:47:13.735523    3712 buffer_comparator.cc:157] Difference at 4: inf, expected 2.01255
E1115 10:47:13.735539    3712 buffer_comparator.cc:157] Difference at 5: inf, expected 2.00876
E1115 10:47:13.735543    3712 buffer_comparator.cc:157] Difference at 6: inf, expected 2.01238
E1115 10:47:13.735546    3712 buffer_comparator.cc:157] Difference at 7: inf, expected 2.00943
E1115 10:47:13.735550    3712 buffer_comparator.cc:157] Difference at 8: inf, expected 2.01083
E1115 10:47:13.735554    3712 buffer_comparator.cc:157] Difference at 9: inf, expected 2.01033

There are too many errors to display them all, but the rest look the same as these.

However, when I made x, y, z smaller, like:

    x = jnp.linspace(2, 3, 5)
    y = jnp.linspace(2, 3, 5)
    z = jnp.linspace(2, 3, 5)

the errors disappeared.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.3
python: 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='eda', release='3.10.0-1160.el7.x86_64', version='#1 SMP Mon Oct 19 16:18:59 UTC 2020', machine='x86_64')


$ nvidia-smi
Fri Nov 15 11:14:06 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:01:00.0 Off |                  Off |
|  0%   42C    P2              34W / 450W |  18657MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      6628      G   /usr/bin/X                                   42MiB |
|    0   N/A  N/A     12953      C   ...y/anaconda3/envs/jax_kan/bin/python    18554MiB |
|    0   N/A  N/A     79518      G   /usr/bin/gnome-shell                         38MiB |
+---------------------------------------------------------------------------------------+
yanboyang97 added the bug label on Nov 15, 2024
@guilherme-salome

I also ran into this error and it seems to be related to the sample size used during training.

@AaronSpieler
AaronSpieler commented Feb 8, 2025

I run into the same error message when executing a custom model. However, it only occurs if:

  1. I don't set jax.config.update("jax_default_matmul_precision", "float32") (see the sketch below),
  2. and I vmap the forward pass of the component along a batch dimension,
  3. and it is executed on a GPU (the CPU doesn't seem to be affected).

Is anybody else seeing the same problem?

I narrowed the problem down to a single matmul in a sub-component; however, when I run the sub-component separately there is no problem, so I cannot really isolate it.
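
For reference, here is a minimal sketch of the two ways the flag from point 1 can be set; the choice of "float32" is just the value I happened to use:

import jax

# Option 1: set the default matmul precision globally, before any jit tracing.
jax.config.update("jax_default_matmul_precision", "float32")

# Option 2: scope the setting to a block with the context manager.
with jax.default_matmul_precision("float32"):
    pass  # code traced/compiled here picks up the float32 setting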

SYSTEM INFO:
OS: x86_64 GNU/Linux
GPU: NVIDIA A100-SXM4-40GB
NVIDIA-SMI 535.216.03
Driver Version: 535.216.03
CUDA Version: 12.2

CONDA ENV:

channels:
  - nvidia
  - pytorch
  - conda-forge
dependencies:
  - python=3.12
  - pip=24.3
  - setuptools=75.6
  - numpy=2.2
  - scipy=1.14
  - scikit-learn=1.6
  - pytorch=2.5.1=*cpu*
  - torchinfo=1.8
  - jax=0.4.35
  - jaxlib=0.4.35=*cuda126*
  - optax=0.2.3
  - equinox=0.11.10
  - jaxtyping=0.2
  - chex=0.1
  - hydra-core=1.3
  - tqdm=4.67
  - pandas=2.2
  - matplotlib=3.9
  - seaborn=0.13
  - jupyterlab=4.3
  - ipywidgets=8.1
  - h5py=3.12
  - urllib3=2.2
  - pytest=8.3
  - pre-commit=4.0
  - wandb=0.19
  - kaggle=1.6

Interestingly, this is the same jax version as the OP's.
I now also have a minimal example recreating the error.

CODE:

import equinox as eqx
import jax
import jax.numpy as jnp

batch_size = 32 # => increased size leads to error
num_mlps = 10 # => doesn't seem to matter

# setting matmul precision can prevent it
#jax.config.update("jax_default_matmul_precision", "float32")

def create_mlp(mlp_key):
    mlp = eqx.nn.MLP(
        in_size=100,
        out_size=10,
        width_size=50,
        depth=1,
        activation=jax.nn.relu,
        final_activation=lambda x: x,
        use_bias=True,
        use_final_bias=True,
        key=mlp_key,
    )
    return mlp

#@eqx.filter_jit # => jit-ing leads to error
def infer_mlp(mlps, data):
    return eqx.filter_vmap(lambda mlp, x: mlp(x))(mlps, data)

@eqx.filter_jit # => jit-ing leads to error
def infer_batch_mlps(mlps, data):
    return eqx.filter_vmap(lambda x: infer_mlp(mlps, x))(data)

key = jax.random.PRNGKey(0)
mlp_key, input_key = jax.random.split(key)
mlp_keys = jax.random.split(mlp_key, num_mlps)

mlps = eqx.filter_vmap(create_mlp)(mlp_keys)
data = jax.random.normal(input_key, (batch_size, num_mlps, 100))
output = infer_batch_mlps(mlps, data)

print(f"Output shape: {output.shape}")

OUTPUT:

Output shape: (32, 10, 10)
E0208 13:42:49.742799   82190 buffer_comparator.cc:157] Difference at 16: 0, expected 31.792
E0208 13:42:49.742839   82190 buffer_comparator.cc:157] Difference at 17: 0, expected 32.434
E0208 13:42:49.742842   82190 buffer_comparator.cc:157] Difference at 18: 0, expected 31.5442
E0208 13:42:49.742845   82190 buffer_comparator.cc:157] Difference at 19: 0, expected 32.2899
E0208 13:42:49.742847   82190 buffer_comparator.cc:157] Difference at 20: 0, expected 31.9846
E0208 13:42:49.742850   82190 buffer_comparator.cc:157] Difference at 21: 0, expected 31.2843
E0208 13:42:49.742852   82190 buffer_comparator.cc:157] Difference at 22: 0, expected 31.597
E0208 13:42:49.742855   82190 buffer_comparator.cc:157] Difference at 23: 0, expected 31.9733
E0208 13:42:49.742857   82190 buffer_comparator.cc:157] Difference at 24: 0, expected 32.4919
E0208 13:42:49.742860   82190 buffer_comparator.cc:157] Difference at 25: 0, expected 28.697
2025-02-08 13:42:49.742874: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0208 13:42:49.743367   82190 buffer_comparator.cc:157] Difference at 16: 0, expected 31.792
E0208 13:42:49.743380   82190 buffer_comparator.cc:157] Difference at 17: 0, expected 32.434
E0208 13:42:49.743383   82190 buffer_comparator.cc:157] Difference at 18: 0, expected 31.5442
E0208 13:42:49.743385   82190 buffer_comparator.cc:157] Difference at 19: 0, expected 32.2899
E0208 13:42:49.743388   82190 buffer_comparator.cc:157] Difference at 20: 0, expected 31.9846
E0208 13:42:49.743390   82190 buffer_comparator.cc:157] Difference at 21: 0, expected 31.2843
E0208 13:42:49.743393   82190 buffer_comparator.cc:157] Difference at 22: 0, expected 31.597
E0208 13:42:49.743395   82190 buffer_comparator.cc:157] Difference at 23: 0, expected 31.9733
E0208 13:42:49.743397   82190 buffer_comparator.cc:157] Difference at 24: 0, expected 32.4919
E0208 13:42:49.743400   82190 buffer_comparator.cc:157] Difference at 25: 0, expected 28.697
2025-02-08 13:42:49.743404: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1180] Results do not match the reference. This is likely a bug/unexpected loss of precision.

@AaronSpieler
AaronSpieler commented Feb 8, 2025

I believe it might be related to this issue in jax versions 0.4.33-0.4.35: #24843.
However, setting the matmul precision doesn't fix that issue, which would indicate that this is a different issue.
For me the error vanishes on version 0.4.28, with an otherwise as-close-as-possible conda environment.

@AaronSpieler

The problem also occurs in newer versions, specifically jax & jaxlib version 0.5.1.

@AaronSpieler

Here is a minimal reproduction with vanilla jax:

import jax
import jax.numpy as jnp

# setting matmul precision prevents the error
#jax.config.update("jax_default_matmul_precision", "float32")

batch_size = 32  # => increased size leads to error
num_mlps = 10    # => doesn't seem to matter

def init_mlp(rng, in_size=100, hidden_size=50, out_size=10):
    """Initialize parameters for a 2-layer MLP."""
    rng_w1, rng_b1, rng_w2, rng_b2 = jax.random.split(rng, 4)
    W1 = jax.random.normal(rng_w1, (in_size, hidden_size))
    b1 = jax.random.normal(rng_b1, (hidden_size,))
    W2 = jax.random.normal(rng_w2, (hidden_size, out_size))
    b2 = jax.random.normal(rng_b2, (out_size,))
    return (W1, b1, W2, b2)

def forward_mlp(params, x):
    """Single forward pass for the MLP."""
    W1, b1, W2, b2 = params
    x = jnp.dot(x, W1) + b1
    x = jax.nn.relu(x)
    x = jnp.dot(x, W2) + b2
    return x

#@jax.jit # => jit-ing leads to error
def infer_mlp(mlps, x):
    """
    Apply each MLP in `mlps` to a corresponding data vector in `x`.
    mlps: (num_mlps,) pytree of parameters
    x:    (num_mlps, in_size)
    """
    return jax.vmap(forward_mlp, in_axes=(0, 0))(mlps, x)

@jax.jit # => jit-ing leads to error
def infer_batch_mlps(mlps, data):
    """
    For each batch element, call `infer_mlp(mlps, x)`.
    data: (batch_size, num_mlps, in_size)
    """
    return jax.vmap(lambda x: infer_mlp(mlps, x))(data)

# --- Main script ---
key = jax.random.PRNGKey(0)

# Create multiple MLP parameters
mlp_keys = jax.random.split(key, num_mlps)
mlps = jax.vmap(init_mlp)(mlp_keys)  # shape: (num_mlps,) of parameter pytree

# Create random input data
data_key = jax.random.PRNGKey(1)
data = jax.random.normal(data_key, (batch_size, num_mlps, 100))

# Run inference
output = infer_batch_mlps(mlps, data)
print("Output shape:", output.shape)

Could you maybe take a look, @jakevdp or @dfm?

@AaronSpieler

I tried the above code on Google Colab with a "v2-8 TPU" and a "T4 GPU", and the error doesn't show up there on jax 0.5.1 or 0.4.33.

Furthermore, I printed the HLO module like so:

hlo_stuff =  infer_batch_mlps.lower(mlps, data).compile().runtime_executable().hlo_modules()[0].to_string()
print(hlo_stuff)

I noticed that only on the TPU is the code purely in "bf16"; on the A100 (and also the T4), even if the model weights and precision are set to bf16, it still performs many "fp32" operations. Is this expected behavior on the A100 as well?
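
For reference, matmul precision can also be pinned per operation instead of globally; a minimal sketch (the shapes and dtypes here are arbitrary placeholders, not taken from the model above):

import jax
import jax.numpy as jnp

x = jnp.ones((32, 100), dtype=jnp.bfloat16)
w = jnp.ones((100, 10), dtype=jnp.bfloat16)

# Per-operation override: this dot requests full-precision accumulation
# regardless of the jax_default_matmul_precision setting.
y = jnp.dot(x, w, precision=jax.lax.Precision.HIGHEST)
print(y.shape)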

@njzjz
njzjz commented May 7, 2025

I got these messages on the A100 card, and jax.config.update("jax_default_matmul_precision", "float32") does work for me. What is the default behavior for jax_default_matmul_precision?
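
As a quick sanity check, one can compare a matmul under the current default against the same matmul with precision forced to float32; a minimal sketch (the matrix size is arbitrary):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (512, 512), dtype=jnp.float32)

# Result under whatever the default matmul precision currently is.
default_out = jax.jit(lambda x: x @ x)(a)

# Result with matmuls forced to full float32 accumulation.
with jax.default_matmul_precision("float32"):
    f32_out = jax.jit(lambda x: x @ x)(a)

# A large difference suggests the default is taking a lower-precision
# path (e.g. tensorfloat32) on this backend.
print(jnp.max(jnp.abs(default_out - f32_out)))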

njzjz added a commit to njzjz/deepmd-kit that referenced this issue May 7, 2025
See jax-ml/jax#24909. Without setting this flag, the precision will become very low.
See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
njzjz added a commit to njzjz/deepmd-kit that referenced this issue May 7, 2025
See jax-ml/jax#24909. Without setting this flag, the precision will become very low, which seems to be a bug.
See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
@njzjz
njzjz commented May 7, 2025

I don't get these messages even if I set default_matmul_precision to tensorfloat32.

github-merge-queue bot pushed a commit to deepmodeling/deepmd-kit that referenced this issue May 8, 2025
See jax-ml/jax#24909. Without setting this flag, the precision will become very low, which seems to be a bug (the documentation says the GPU uses tensorfloat32 or float32, but the default behavior seems wrong...).
See https://docs.jax.dev/en/latest/_autosummary/jax.default_matmul_precision.html for what this option is.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>