torch.linalg.solve yields much lower precision in 1.13.0 than in previous versions #90453
Comments
cc: @lezcano, who is a linalg guru.
That matrix is very badly conditioned:

>>> torch.linalg.svdvals(A)
tensor([8.2942e+01, 2.2421e-02, 9.6433e-07, 5.4287e-07, 4.8071e-07, 4.6181e-07,
        3.8026e-07, 2.1794e-07, 2.7241e-12, 3.0295e-15])
>>> torch.linalg.cond(A)
tensor(2.7378e+16)

As such, errors of that order are expected.
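As a back-of-the-envelope check of why errors of that order are expected, here is a minimal sketch using the standard numerical-analysis rule of thumb that the relative forward error of a linear solve scales like cond(A) times machine epsilon (the rule of thumb is an assumption about typical behaviour, not a statement about this exact matrix):

```python
import torch

# Rule of thumb: relative forward error of a solve ~ cond(A) * machine epsilon.
eps = torch.finfo(torch.float64).eps   # ~2.22e-16
cond = 2.7378e16                       # torch.linalg.cond(A) printed above
print(cond * eps)                      # ~6.1 -> no reliable digits in x itself
```

Note the distinction: the residual A @ x - b can still be tiny for a backward-stable solver even when x itself is far off; that distinction drives the rest of this thread.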
Yes, they are ill-conditioned, but unavoidably so: they are physically motivated. Do you mind elaborating a bit on what has changed?
@lezcano I tried to find what has changed in the blog posts but didn't find any clue behind this change of behavior. Thus, I do wonder if this precision change is permanent, or is there some hope of reviving the old behavior?
@lezcano I compiled PyTorch from source and can confirm that the commit 54949a5 is responsible for this behavior change. Here I want to argue that, even if the matrix is ill-conditioned, the residual should be around machine precision (one may question the correctness of the solution itself, but that's another issue).

Using torch 1.12.1:

In [4]: (A @ torch.linalg.solve(A, b) - b) / b
Out[4]:
tensor([ 2.0885e-16,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6068e-16,
        -3.6894e-16,  0.0000e+00, -0.0000e+00,  3.6203e-16,         nan])

Using torch 1.13.0:

In [4]: (A @ torch.linalg.solve(A, b) - b) / b
Out[4]:
tensor([-2.0885e-16, -6.9360e-16, -2.0364e-16,  0.0000e+00, -1.7675e-15,
        -5.3312e-13,  1.5038e-11,  1.3604e-11,  1.3794e-12,         nan])

However, my current knowledge is not enough to understand what changed in the commit 54949a5. Any ideas/suggestions would be greatly appreciated!
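To make the residual-vs-solution distinction being argued here concrete, below is a self-contained sketch; the ill-conditioned A and the vectors are random stand-ins for the report's data (which is not reproduced in this extract):

```python
import torch

torch.manual_seed(0)
n = 10

# Build an ill-conditioned A with a known exact solution x_true.
U, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
V, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
s = torch.logspace(0, -14, n, dtype=torch.float64)   # singular values 1e0 .. 1e-14
A = U @ torch.diag(s) @ V.mT                         # cond(A) ~ 1e14
x_true = torch.randn(n, dtype=torch.float64)
b = A @ x_true

x = torch.linalg.solve(A, b)
print((A @ x - b).norm() / b.norm())         # residual: small for a backward-stable solver
print((x - x_true).norm() / x_true.norm())   # forward error: inflated by cond(A)
```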
That commit is mostly a better-engineering commit, but it does introduce an optimisation: to avoid performing a copy of the matrix when it is row-major, we now ask the backend to solve the transposed problem instead (see below).

In my opinion, errors of that order are reasonable for linear algebra. As discussed in the note linked in #90453 (comment), it is expected that different devices and different backends give marginally different solutions. If you do care about these differences, what I would suggest is that you try to find a small repro in C / C++ and report the accuracy mismatch to the relevant BLAS implementation you are using.
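To make that optimisation concrete, here is a sketch of the two code paths in terms of raw LAPACK routines, using SciPy's LAPACK wrappers as a stand-in for whatever BLAS/LAPACK backend PyTorch dispatches to; this illustrates the idea only and is not PyTorch's actual code:

```python
import numpy as np
from scipy.linalg import lapack

rng = np.random.default_rng(0)
A = rng.standard_normal((10, 10))   # row-major (C-contiguous), as in PyTorch
b = rng.standard_normal((10, 1))

# Old path (pre-commit): copy A into column-major storage, then solve A x = b.
lu, piv, _ = lapack.dgetrf(np.asfortranarray(A))
x_copy, _ = lapack.dgetrs(lu, piv, b, trans=0)

# New path: the row-major buffer of A is exactly the column-major buffer of
# A.T, so the backend can factor A.T without any layout transposition and
# then be asked for the *transposed* solve: (A.T)^T x = A x = b.
lu_t, piv_t, _ = lapack.dgetrf(A.T)   # A.T is already Fortran-ordered
x_trans, _ = lapack.dgetrs(lu_t, piv_t, b, trans=1)

# Both are valid solves, but they round differently; for an ill-conditioned A
# the two answers can disagree well above machine epsilon.
print(np.abs(x_copy - x_trans).max())
```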
@lezcano Thanks a lot for the explanations. Sorry if I misunderstand something, but I am not using one specific BLAS implementation -- the error seems to be quite consistent across very different devices/platforms and BLAS backends:

1. The test results in the first comment of this issue were from WSL2 Ubuntu 20.04 on a laptop with an Intel i7-11800H + RTX 3070 Mobile:

>>> print(torch.__config__.show())
PyTorch built with:
- GCC 9.3
- C++ Version: 201402
- Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX2
- CUDA Runtime 11.7
- NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86
- CuDNN 8.5
- Magma 2.6.1
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.5.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.13.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

2. The test results in the sixth comment (where I found the exact commit responsible for this issue) were from a CentOS 7 server with an Intel(R) Xeon(R) CPU E5-2650 v2, where I only (compiled and) used the CPU build:

>>> print(torch.__config__.show())
PyTorch built with:
- GCC 8.3
- C++ Version: 201402
- Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: NO AVX
- Build settings: BLAS_INFO=generic, BUILD_TYPE=Release, CXX_COMPILER=/opt/ohpc/pub/compiler/gcc/8.3.0/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, FORCE_FALLBACK_CUDA_MPI=1, LAPACK_INFO=generic, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.13.0, USE_CUDA=OFF, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=ON, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

3. On a MacBook Pro with M1 Pro, I can also reproduce the errors:

>>> # ===== on CPU since MPS doesn't support linalg.solve yet =====
>>> (A @ torch.linalg.solve(A, b) - b) / b
tensor([ 0.0000e+00, 2.3120e-16, -2.0364e-16, -1.7857e-16, -4.0170e-15,
-2.1768e-14, 5.9627e-11, -1.7222e-11, -5.9982e-12, nan])
>>> # ===== config =====
>>> print(torch.__config__.show())
PyTorch built with:
- GCC 4.2
- C++ Version: 201402
- clang 14.0.0
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: NO AVX
- Build settings: BLAS_INFO=accelerate, BUILD_TYPE=Release, CXX_COMPILER=/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang++, CXX_FLAGS=-Wno-error=bitwise-instead-of-logical -fvisibility-inlines-hidden -Wno-deprecated-declarations -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_PYTORCH_METAL -DUSE_PYTORCH_METAL_EXPORT -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wvla-extension -Wno-range

4. On an Ubuntu 20.04 server with ARM Neoverse-N1, I can also reproduce the errors:

>>> # ===== also only on CPU =====
>>> (A @ torch.linalg.solve(A, b) - b) / b
tensor([ 2.0885e-16, -1.1560e-16, -2.0364e-16, -1.0714e-15, 0.0000e+00,
-5.1468e-14, 7.4451e-12, 1.6115e-11, -1.4156e-12, nan])
>>> # ===== config (note that `BLAS_INFO=open`) =====
>>> print(torch.__config__.show())
PyTorch built with:
- GCC 10.2
- C++ Version: 201402
- Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: NO AVX
- Build settings: BLAS_INFO=open, BUILD_TYPE=Release, CXX_COMPILER=/opt/rh/devtoolset-10/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=open, TORCH_VERSION=1.13.0, USE_CUDA=OFF, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

5. On a CentOS 7.9 server, this time with AMD EPYC 7552 + Nvidia Tesla V100S-PCIE-32GB, I can also reproduce the errors:

>>> # ===== on CPU =====
>>> (A @ torch.linalg.solve(A, b) - b) / b
tensor([ 0.0000e+00, -6.9360e-16, -4.0727e-16, 5.3570e-16, -1.2854e-15,
-3.3205e-13, 1.7105e-11, 1.4818e-11, 8.1637e-13, nan])
>>> # ===== on GPU =====
>>> (A @ torch.linalg.solve(A, b) - b) / b
tensor([ 0.0000e+00, -2.3120e-16, 0.0000e+00, -3.5713e-16, -2.2495e-15,
7.2682e-14, 5.9152e-11, -7.6224e-12, -4.7471e-12, nan],
device='cuda:0')
>>> # ===== config =====
>>> print(torch.__config__.show())
PyTorch built with:
- GCC 9.4
- C++ Version: 201402
- Intel(R) Math Kernel Library Version 2020.0.4 Product Build 20200917 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v2.6.0 (Git Hash N/A)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX2
- CUDA Runtime 11.7
- NVCC architecture flags: -gencode;arch=compute_52,code=sm_52;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_86,code=compute_86
- CuDNN 8.4.1 (built against CUDA 11.6)
- Magma 2.6.2
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.4.1, CXX_COMPILER=/usr/bin/c++

I can do more tests on more platforms/backends if necessary (if these turn out to be useful tests).
Thank you for the rather comprehensive dissection! Alas, I still think that this error is coming from the backend implementations. As you can see in that PR, we dispatch to the same backends; the only difference is that before the PR, when the matrix A was row-major, we copied it into a column-major matrix (which is slow) and then called the backend, while now we call the backend on the transposed problem directly. Either way, we are simply calling the given backends and performing a semantics-preserving optimisation as per the backend docs. I think this particular discrepancy should be reported to the relevant backends, given that you have a concrete reproducer.

Note that having less well-maintained paths in some of these libraries is not uncommon. We have found many bugs in different libraries over the years. For example, for MAGMA we discovered a problem in the path at pytorch/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp, lines 2460 to 2462 (at 983d4f6).
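For readers following along, the row-major vs column-major distinction above can be checked directly on a tensor's strides; this is a small illustrative snippet, not part of the original discussion:

```python
import torch

A = torch.randn(4, 4)          # torch tensors are row-major (C order) by default
print(A.stride())              # (4, 1): row-major
print(A.mT.is_contiguous())    # False here; True would mean A is column-major

A_cm = A.mT.contiguous().mT    # one copy; same values, column-major layout
print(A_cm.stride())           # (1, 4)
print(A_cm.mT.is_contiguous()) # True: column-major, so no transposed-solve trick is needed
```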
You can also see the error bounds in the backend docs. Now, if you want to recover the previous implementation, you can bypass the optimisation that the PR added altogether by making the matrix column-major:

A = A.mT.contiguous().mT

You can see that by doing this you get the same results as with PyTorch 1.12. Of course, as in PyTorch 1.12 (before that speed optimisation was implemented), you have to copy the data once.
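A quick self-contained check of that workaround (a sketch: A and b are random stand-ins here, and the two residuals will only differ noticeably for ill-conditioned inputs):

```python
import torch

torch.manual_seed(0)
A = torch.randn(10, 10, dtype=torch.float64)   # row-major: takes the 1.13 fast path
b = torch.randn(10, dtype=torch.float64)

x_fast = torch.linalg.solve(A, b)

# Workaround from the comment above: force a column-major copy so the
# backend is called exactly as in PyTorch 1.12 (at the cost of one copy).
A_cm = A.mT.contiguous().mT
x_112 = torch.linalg.solve(A_cm, b)

print((A @ x_fast - b).abs().max().item())
print((A @ x_112 - b).abs().max().item())
```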
Thank you very much for the explanations and the workaround 👍! My apologies for the late response. Lesson learned that there are various codepaths maintained at different levels (indeed fascinating 😂).
Closing this. Feel free to reopen if you have further issues.
🐛 Describe the bug
After upgrading to `torch 1.13.0`, `torch.linalg.solve` suddenly gives solutions with much lower precision, regardless of device (`cpu` or `gpu`) or dtype (`float64` or `float32`). The errors quickly escalate in my numerical calculations and break down my simulations.

Take the following data as an example (I know it is somewhat ill-conditioned, but the change in behavior is real).
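The data itself did not survive this extract; as a hypothetical stand-in (not the original arrays), one can construct a matrix with the singular values quoted earlier in the thread and run the same residual check:

```python
import torch

# Hypothetical stand-in for the report's A and b; the singular values mimic
# the torch.linalg.svdvals(A) output quoted above.
torch.manual_seed(0)
sv = torch.tensor([8.2942e+01, 2.2421e-02, 9.6433e-07, 5.4287e-07, 4.8071e-07,
                   4.6181e-07, 3.8026e-07, 2.1794e-07, 2.7241e-12, 3.0295e-15],
                  dtype=torch.float64)
U, _ = torch.linalg.qr(torch.randn(10, 10, dtype=torch.float64))
V, _ = torch.linalg.qr(torch.randn(10, 10, dtype=torch.float64))
A = U @ torch.diag(sv) @ V.mT
b = torch.randn(10, dtype=torch.float64)

device = "cuda" if torch.cuda.is_available() else "cpu"
A, b = A.to(device), b.to(device)
x = torch.linalg.solve(A, b)
print(((A @ x - b) / b).abs().max().item())   # elementwise relative residual
```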
With `torch 1.12.1`, the relative errors are around machine precision (a few 1e-16), which is consistent with the precision obtained from `numpy` or `cupy`.

However, with `torch 1.13.0`, the relative errors are huge (max at 5e-11).

Below are more comparisons using `torch.float64` and `cuda`.

And more comparisons using `torch.float32` and `cpu`.
Versions

For tests with `torch 1.12.1`, the output is:

Collecting environment information...
PyTorch version: 1.12.1+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.8.10 (default, Sep 28 2021, 16:10:42) [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.15.79.1-microsoft-standard-WSL2-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU
Nvidia driver version: 527.37
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] pytorch-memlab==0.2.4
[pip3] torch==1.12.1+cu116
[pip3] torchaudio==0.12.1+cu116
[pip3] torchvision==0.13.1+cu116
[pip3] xitorch==0.3.0
[conda] No relevant packages
For tests with `torch 1.13.0`, the output is:

PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.8.10 (default, Jun 22 2022, 20:18:18) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.79.1-microsoft-standard-WSL2-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU
Nvidia driver version: 527.37
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==1.13.0
[pip3] torchaudio==0.13.0
[pip3] torchvision==0.14.0
[conda] No relevant packages
cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano