`torch.linalg.solve` yields much lower precisions in `1.13.0` than previous versions · Issue #90453 · pytorch/pytorch · GitHub



Closed
astroboylrx opened this issue Dec 8, 2022 · 11 comments
Labels
module: linear algebra, triaged

Comments

@astroboylrx
astroboylrx commented Dec 8, 2022

🐛 Describe the bug

After upgrading to torch 1.13.0, torch.linalg.solve suddenly gives solutions with much lower precision, regardless of device (CPU or GPU) or dtype (float64 or float32). The errors quickly accumulate in my numerical calculations and break my simulations.

Take the following data as an example (I know the matrix is somewhat ill-conditioned, but the change in behavior is real):

import torch
torch.set_default_dtype(torch.float64)
torch.backends.cuda.matmul.allow_tf32 = False
A = torch.tensor([
    [ 3.8025705376834739e-07, -9.1719365342788720e-07, -6.7124337949782264e-06, -6.4837019110456791e-05, -7.0869999797614066e-04, -1.0694859984690733e-02, -3.2912231531790004e-01, -6.6347339870464399e+00, -8.2509761085708249e+01,  0.0000000000000000e+00],
    [ 0.0000000000000000e+00,  4.4000124553730829e-07, -5.5080918253708871e-07, -5.1498277032055974e-06, -5.7818057148617599e-05, -9.1226448867859551e-04, -2.2619326362175465e-02, -4.4038788530099793e-01, -5.1992675801721502e+00,  0.0000000000000000e+00],
    [ 0.0000000000000000e+00, -1.0669700681643825e-10,  4.3768558191229986e-07, -4.3974816153203019e-07, -4.8865127972067992e-06, -7.8116560507683326e-05, -1.7589402883070333e-03, -3.3666362131922367e-02, -3.8659142733749491e-01,  0.0000000000000000e+00],
    [ 0.0000000000000000e+00, -7.8216940301197729e-12, -1.5895421888461478e-10,  4.3542984469163267e-07, -4.0043248885844276e-07, -6.6798905178796823e-06, -1.3761857019311234e-04, -2.5943507621790695e-03, -2.9003633389177604e-02,  0.0000000000000000e+00],
    [ 0.0000000000000000e+00, -2.4603969583879200e-13, -6.0925772512004975e-12, -1.9886454656863128e-10,  4.3370279880257098e-07, -5.6639032522315289e-07, -1.0649799471193429e-05, -1.9808440853565822e-04, -2.1583707954594099e-03,  0.0000000000000000e+00],
    [ 0.0000000000000000e+00, -1.4999959257460881e-15, -3.2831398418930186e-14, -8.8714562886788080e-13, -4.3280772005187299e-11,  4.4148762039828565e-07, -6.8089481270669943e-07, -1.4575015323337058e-05, -1.5597848962814291e-04,  0.0000000000000000e+00],
    [ 0.0000000000000000e+00, -3.6858575028157790e-16, -7.2036090445864899e-15, -1.4349791509103240e-13, -2.9849302443991965e-12,  6.3914122655929791e-10,  4.6448551809896547e-07, -6.8453604307207769e-07, -1.0332761488908590e-05,  0.0000000000000000e+00],
    [ 0.0000000000000000e+00, -3.7045642770024088e-17, -7.2015144280333478e-16, -1.4158860652466324e-14, -2.8662564585632735e-13, -6.2285079180541528e-12,  1.5090963357302090e-09,  4.8979817748389458e-07, -1.2863401745116974e-07,  0.0000000000000000e+00],
    [ 0.0000000000000000e+00, -2.3760629594245614e-18, -4.6007155546998113e-17, -8.9513844792609796e-16, -1.7640414722799569e-14, -3.5935860384434572e-13, -7.9429359080595169e-12,  2.0146206213869421e-09,  4.7959403001188342e-07,  0.0000000000000000e+00],
    [ 0.0000000000000000e+00,  0.0000000000000000e+00,  0.0000000000000000e+00,  0.0000000000000000e+00,  0.0000000000000000e+00,  0.0000000000000000e+00,  0.0000000000000000e+00,  0.0000000000000000e+00,  0.0000000000000000e+00,  3.8025705376834739e-07]
])
b = torch.tensor(
    [ 6.9677181015078851e+04,  3.9337825712781823e+03,  2.7914109655787729e+02,  1.9895852311404216e+01,  1.3819016836738420e+00,  7.5229947004102571e-02,  1.3433804143281360e-03, -3.1421146091483441e-04, -2.8076324348838071e-05,  0.0000000000000000e+00]
)

With torch 1.12.1, the component-wise relative residuals are around machine precision (a few 1e-16), consistent with the precision obtained from NumPy or CuPy:

In [1]: (A @ torch.linalg.solve(A, b) - b) / b
tensor([ 0.0000000000000000e+00,  0.0000000000000000e+00,  0.0000000000000000e+00,  0.0000000000000000e+00,  1.6068046486108669e-16,
        -3.6894317650011501e-16,  0.0000000000000000e+00, -0.0000000000000000e+00,  3.6202728109145290e-16,                     nan])

However, with torch 1.13.0, the relative residuals are much larger (up to about 5e-11):

In [2]: (A @ torch.linalg.solve(A, b) - b) / b
tensor([-2.0884764590602007e-16,  4.6240212075443264e-16,  0.0000000000000000e+00, -1.7856554337026822e-16, -4.1776920863882539e-15,
        -8.7255061242277206e-14,  5.0944524844510106e-11, -2.0456409676328997e-11, -4.9269499441339466e-12,                     nan])
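The check above divides the residual by b component by component, which is why the last entry (b = 0, residual exactly 0) prints nan, and why small entries of b inflate the ratio. Reading the 1.13.0 numbers above: the largest ratio (about 5.1e-11) sits on a component where b is about 1.3e-3, i.e. an absolute residual near 7e-14, while the first component (b about 7.0e4) has a ratio near 2.1e-16, i.e. an absolute residual near 1.5e-11. That puts ||r||∞/||b||∞ at roughly 2e-16, so in the normwise sense the 1.13.0 solve still looks backward stable to about machine precision. Below is a minimal sketch of a hypothetical helper, `residual_report`, that reports both views; the helper name and the choice of the infinity norm are assumptions for illustration, not part of the original report.

```python
import torch


def residual_report(A, x, b):
    """Return (max componentwise relative residual, normwise backward error).

    Entries with b_i == 0 are skipped, since (A @ x - b) / b gives nan there.
    The normwise backward error uses the infinity norm:
        ||A x - b|| / (||A|| * ||x|| + ||b||)
    """
    r = A @ x - b
    nonzero = b != 0
    componentwise = (r[nonzero].abs() / b[nonzero].abs()).max().item()
    inf = float("inf")
    norm_r = torch.linalg.vector_norm(r, ord=inf)
    norm_A = torch.linalg.matrix_norm(A, ord=inf)
    norm_x = torch.linalg.vector_norm(x, ord=inf)
    norm_b = torch.linalg.vector_norm(b, ord=inf)
    normwise = (norm_r / (norm_A * norm_x + norm_b)).item()
    return componentwise, normwise


# Usage on a small, well-conditioned system whose b contains a zero entry
A_demo = torch.tensor([[2.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
b_demo = torch.tensor([3.0, 0.0], dtype=torch.float64)
x_demo = torch.linalg.solve(A_demo, b_demo)
comp, eta = residual_report(A_demo, x_demo, b_demo)
print(f"max componentwise relative residual: {comp:.1e}")
print(f"normwise backward error:             {eta:.1e}")
```

Calling `residual_report(A, torch.linalg.solve(A, b), b)` with the A and b from this report would give both numbers for a given torch version and backend.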

Below are more comparisons using torch.float64 on CUDA:

In [1]: A = torch.tensor([ ... ], device=torch.device('cuda'))
In [2]: b = torch.tensor([ ... ], device=torch.device('cuda'))
In [3]: (A @ torch.linalg.solve(A, b) - b) / b  # with torch 1.12.1
tensor([ 0.0000e+00,  1.1560e-16,  0.0000e+00,  0.0000e+00,  0.0000e+00,
        -1.8447e-16,  0.0000e+00,  1.7253e-16,  3.6203e-16,         nan],
       device='cuda:0')
In [4]: (A @ torch.linalg.solve(A, b) - b) / b  # with torch 1.13.0
tensor([-2.0885e-16,  0.0000e+00, -2.0364e-16, -7.1426e-16, -1.7675e-15,
         4.1875e-14,  4.4228e-11, -1.2897e-11, -3.0743e-12,         nan],
       device='cuda:0')

And more comparisons using torch.float32 on CPU:

In [1]: torch.set_default_dtype(torch.float32)
In [2]: torch.backends.cuda.matmul.allow_tf32 = True
In [3]: (A @ torch.linalg.solve(A, b) - b) / b  # with torch 1.12.1
tensor([-1.1212e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  8.6265e-08,
         1.9807e-07, -8.6658e-08,  9.2625e-08, -0.0000e+00,         nan])
In [4]: (A @ torch.linalg.solve(A, b) - b) / b  # with torch 1.13.0
tensor([-1.1212e-07,  6.2063e-08, -1.0933e-07, -9.5867e-08, -2.3291e-06,
        -4.0902e-05, -2.2294e-02, -2.5929e-03, -1.9909e-03,         nan])

Versions

For tests with torch 1.12.1, the output is:

Collecting environment information...
PyTorch version: 1.12.1+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.10 (default, Sep 28 2021, 16:10:42)  [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.15.79.1-microsoft-standard-WSL2-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU
Nvidia driver version: 527.37
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.5
[pip3] pytorch-memlab==0.2.4
[pip3] torch==1.12.1+cu116
[pip3] torchaudio==0.12.1+cu116
[pip3] torchvision==0.13.1+cu116
[pip3] xitorch==0.3.0
[conda] No relevant packages

For tests with torch 1.13.0, the output is:

PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.10 (default, Jun 22 2022, 20:18:18)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.79.1-microsoft-standard-WSL2-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU
Nvidia driver version: 527.37
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==1.13.0
[pip3] torchaudio==0.13.0
[pip3] torchvision==0.14.0
[conda] No relevant packages

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano

@soumith
Member
soumith commented Dec 8, 2022

cc: @lezcano who is a linalg guru

@lezcano
Collaborator
lezcano commented Dec 8, 2022

That matrix is very badly conditioned:

>>> torch.linalg.svdvals(A)
tensor([8.2942e+01, 2.2421e-02, 9.6433e-07, 5.4287e-07, 4.8071e-07, 4.6181e-07,
        3.8026e-07, 2.1794e-07, 2.7241e-12, 3.0295e-15])
>>> torch.linalg.cond(A)
tensor(2.7378e+16)

As such, errors of that order are expected.
See https://pytorch.org/docs/master/notes/numerical_accuracy.html#extremal-values-in-linalg
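To put that condition number in perspective: a standard rule of thumb for a solver that is backward stable in practice, such as LU with partial pivoting, is that the relative error in x can be as large as roughly cond(A)·u, where u is the unit roundoff of the dtype (up to a modest, dimension-dependent factor). The sketch below only redoes that arithmetic with the singular values printed above; it does not touch A itself.

```python
import torch

# Largest and smallest singular values printed by torch.linalg.svdvals(A) above
sigma_max, sigma_min = 8.2942e01, 3.0295e-15
cond_2 = sigma_max / sigma_min  # 2-norm condition number, ~2.74e16 (cf. torch.linalg.cond(A))

for dtype in (torch.float64, torch.float32):
    u = torch.finfo(dtype).eps / 2  # unit roundoff of this dtype
    # Rough scale of the worst-case relative forward error in x
    print(f"{dtype}: cond_2(A) * u ~ {cond_2 * u:.1e}")
```

Both products exceed 1, so in the worst case neither dtype guarantees any correct digits of x, and in float32 the matrix is numerically singular. The residual checks in this issue measure something different (how well A x reproduces b), which is why they can still look small.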

@astroboylrx
Author

That matrix is very badly conditioned
as such errors of that order are expected. See https://pytorch.org/docs/master/notes/numerical_accuracy.html#extremal-values-in-linalg

Yes, they are ill-conditioned, but they are unavoidably, physically motivated. Could you elaborate a bit on what changed in 1.13.0? Are there settings I can use to make torch produce results similar to previous versions?

@astroboylrx
Author

@lezcano I tried to find what changed in the release blog posts but found no clue to this change in behavior.

With torch 1.12.1 or earlier, I obtain the same answers in my test cases (including the example above) from torch, NumPy, and CuPy. I have also already spent quite some time adapting my code from NumPy to torch to take advantage of GPU acceleration.

So I wonder whether this precision change is permanent, or whether there is some hope of getting the old behavior back.

@astroboylrx
Author
astroboylrx commented Dec 8, 2022

@lezcano I compiled PyTorch from source and can confirm that commit 54949a5 is responsible for this change in behavior. I would argue that, even if the matrix is ill-conditioned, the residual should be around machine precision (one may question the correctness of the solution itself, but that is another issue).


Using torch at the commit just before it (65a3792) gives the expected results:

In [4]: (A @ torch.linalg.solve(A, b) - b) / b
Out[4]: 
tensor([ 2.0885e-16,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.6068e-16,
        -3.6894e-16,  0.0000e+00, -0.0000e+00,  3.6203e-16,         nan])

Using torch at 54949a5 gives the new results:

In [4]: (A @ torch.linalg.solve(A, b) - b) / b
Out[4]: 
tensor([-2.0885e-16, -6.9360e-16, -2.0364e-16,  0.0000e+00, -1.7675e-15,
        -5.3312e-13,  1.5038e-11,  1.3604e-11,  1.3794e-12,         nan])

However, I don't know enough to understand what changed in commit 54949a5. Any ideas or suggestions would be greatly appreciated!

@lezcano
Collaborator
lezcano commented Dec 9, 2022

That commit is mostly a better-engineering (refactoring) commit. However, it does introduce one optimisation: to avoid copying the matrix A when A is C-contiguous (row-major), we factorise A^T instead and then call lu_solve with adjoint=True. It may be that the backend library (the BLAS/LAPACK implementation you use) is slightly less precise when called with adjoint=True.

In my opinion, errors of that order are reasonable for linear algebra. As discussed in the note linked in #90453 (comment), different devices and different backends are expected to give marginally different solutions. If you do care about these differences, I would suggest finding a small reproducer in C/C++ and reporting the accuracy mismatch to the BLAS implementation you are using.
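The optimisation described in this comment rests on the identity (A^T)^T = A: factorising A^T and solving the adjoint system gives the same x as factorising A directly, in exact arithmetic. A minimal sketch using the public `torch.linalg.lu_factor` and `torch.linalg.lu_solve(..., adjoint=True)` API (available in recent PyTorch releases) to compare the two paths on a random matrix follows. It only mirrors the idea at the Python level; which getrs call (trans='N' or 'T') PyTorch issues internally for each path is an assumption, not something this snippet checks, and `solve_both_ways` is a hypothetical helper.

```python
import torch


def solve_both_ways(A, b):
    """Solve A x = b along two mathematically equivalent LU paths.

    x_n: factorise A itself and solve A x = b directly.
    x_t: factorise A^T and solve the adjoint system (A^T)^H x = b,
         which for real A is again A x = b.
    """
    B = b.unsqueeze(-1)  # lu_solve expects an (n, k) right-hand side
    LU, pivots = torch.linalg.lu_factor(A)
    x_n = torch.linalg.lu_solve(LU, pivots, B)
    LU_t, pivots_t = torch.linalg.lu_factor(A.mT)
    x_t = torch.linalg.lu_solve(LU_t, pivots_t, B, adjoint=True)
    return x_n.squeeze(-1), x_t.squeeze(-1)


# Usage on a random, well-conditioned float64 system
torch.manual_seed(0)
n = 6
A_sq = torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
b_sq = torch.randn(n, dtype=torch.float64)
x_n, x_t = solve_both_ways(A_sq, b_sq)
print((x_n - x_t).abs().max().item())  # tiny, but not necessarily exactly 0
```

Passing the A and b from this issue instead, and applying the `(A @ x - b) / b` check from the report to each result, is one way to look for a path-dependent accuracy difference on a given backend.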

@astroboylrx
Author
astroboylrx commented Dec 9, 2022

@lezcano Thanks a lot for the explanations. Sorry if I misunderstand something, but I am not using a specific BLAS implementation: the error seems to be quite consistent across very different devices, platforms, and BLAS backends (MKL, CUDA, generic, OpenBLAS; see the tests below). Is there any way to confirm that the issue is due to a particular BLAS implementation? Thank you again in advance.


1. The test results in the first comment of this issue were from WSL2 Ubuntu 20.04 on a laptop with an Intel i7-11800H + RTX 3070 Mobile, where the CPU results used MKL and the CUDA results used (I think) cuBLAS, or otherwise something in MAGMA. The config is shown below:

>>> print(torch.__config__.show())
PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.7
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86
  - CuDNN 8.5
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.5.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.13.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

2. The test results in the sixth comment (where I identified the exact commit responsible for this issue) were from a CentOS 7 server with an Intel(R) Xeon(R) CPU E5-2650 v2, where I compiled and used only the CPU build, with BLAS_INFO=generic and USE_MKL=OFF. The config is shown below:

>>> print(torch.__config__.show())
PyTorch built with:
  - GCC 8.3
  - C++ Version: 201402
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: NO AVX
  - Build settings: BLAS_INFO=generic, BUILD_TYPE=Release, CXX_COMPILER=/opt/ohpc/pub/compiler/gcc/8.3.0/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, FORCE_FALLBACK_CUDA_MPI=1, LAPACK_INFO=generic, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.13.0, USE_CUDA=OFF, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=ON, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

3. On a MacBook Pro with an M1 Pro, I can also reproduce the errors (with torch installed via pip or compiled by MacPorts):

>>> # ===== on CPU since MPS doesn't support linalg.solve yet =====
>>> (A @ torch.linalg.solve(A, b) - b) / b
tensor([ 0.0000e+00,  2.3120e-16, -2.0364e-16, -1.7857e-16, -4.0170e-15,
        -2.1768e-14,  5.9627e-11, -1.7222e-11, -5.9982e-12,         nan])
>>> # ===== config =====
>>> print(torch.__config__.show())
PyTorch built with:
  - GCC 4.2
  - C++ Version: 201402
  - clang 14.0.0
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: NO AVX
  - Build settings: BLAS_INFO=accelerate, BUILD_TYPE=Release, CXX_COMPILER=/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang++, CXX_FLAGS=-Wno-error=bitwise-instead-of-logical -fvisibility-inlines-hidden -Wno-deprecated-declarations -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_PYTORCH_METAL -DUSE_PYTORCH_METAL_EXPORT -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wvla-extension -Wno-range-loop-analysis -Wno-pass-failed -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -Wconstant-conversion -Wno-invalid-partial-specialization -Wno-typedef-redefinition -Wno-unused-private-field -Wno-inconsistent-missing-override -Wno-c++14-extensions -Wno-constexpr-not-const -Wno-missing-braces -Wunused-lambda-capture -Wunused-local-typedef -Qunused-arguments -fcolor-diagnostics -fdiagnostics-color=always -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -DUSE_MPS -fno-objc-arc -Wno-unguarded-availability-new -Wno-unused-private-field -Wno-missing-braces -Wno-c++14-extensions -Wno-constexpr-not-const, LAPACK_INFO=accelerate, TORCH_VERSION=1.13.0, USE_CUDA=OFF, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=ON, USE_GLOG=ON, USE_LITE_PROTO=1, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=OFF, USE_ROCM=OFF,

4. On an Ubuntu 20.04 server with ARM Neoverse-N1, I can also reproduce the errors with OpenBLAS:

>>> # ===== also only on CPU =====
>>> (A @ torch.linalg.solve(A, b) - b) / b
tensor([ 2.0885e-16, -1.1560e-16, -2.0364e-16, -1.0714e-15,  0.0000e+00,
        -5.1468e-14,  7.4451e-12,  1.6115e-11, -1.4156e-12,         nan])
>>> # ===== config (note that `BLAS_INFO=open`) =====
>>> print(torch.__config__.show())
PyTorch built with:
  - GCC 10.2
  - C++ Version: 201402
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: NO AVX
  - Build settings: BLAS_INFO=open, BUILD_TYPE=Release, CXX_COMPILER=/opt/rh/devtoolset-10/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=open, TORCH_VERSION=1.13.0, USE_CUDA=OFF, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

5. On a CentOS 7.9 server with an AMD EPYC 7552 + NVIDIA Tesla V100S-PCIE-32GB, I can also reproduce the errors:

>>> # ===== on CPU =====
>>> (A @ torch.linalg.solve(A, b) - b) / b
tensor([ 0.0000e+00, -6.9360e-16, -4.0727e-16,  5.3570e-16, -1.2854e-15,
        -3.3205e-13,  1.7105e-11,  1.4818e-11,  8.1637e-13,         nan])
>>> # ===== on GPU =====
>>> (A @ torch.linalg.solve(A, b) - b) / b
tensor([ 0.0000e+00, -2.3120e-16,  0.0000e+00, -3.5713e-16, -2.2495e-15,
         7.2682e-14,  5.9152e-11, -7.6224e-12, -4.7471e-12,         nan],
       device='cuda:0')
>>> # ===== config =====
>>> print(torch.__config__.show())
PyTorch built with:
  - GCC 9.4
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.4 Product Build 20200917 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash N/A)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.7
  - NVCC architecture flags: -gencode;arch=compute_52,code=sm_52;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_86,code=compute_86
  - CuDNN 8.4.1  (built against CUDA 11.6)
  - Magma 2.6.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.4.1, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS=-fno-gnu-unique -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.13.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=ON, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

I can run more tests on more platforms/backends if necessary (if these turn out to be useful).

@lezcano
Collaborator
lezcano commented Dec 9, 2022

Thank you for the rather comprehensive dissection!

Alas, I still think that this error comes from the backend implementation. As you can see in that PR, we dispatch to the same backends. The only difference is that before the PR, when the matrix A was row-major, we copied it into a column-major matrix (which is slow) and then called the backend getrs with trans='N'. The PR introduces an optimisation: rather than copying A into a column-major matrix, we simply call getrs with trans='T'. This is mathematically equivalent, but your experiments make it clear that this codepath may not be as numerically stable as the trans='N' path when called with ill-conditioned matrices.

Again, we are simply calling the given backends and performing a semantics-preserving optimisation as per the backend docs. I think this particular discrepancy should be reported to the relevant backends, given that you have a concrete reproducer.
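Why no copy is needed on the trans='T' path: a column-major (Fortran-order) LAPACK routine that reads the row-major buffer of A sees A^T. The sketch below only illustrates that layout fact with PyTorch strides on a small example matrix; it does not call LAPACK.

```python
import torch

# A row-major (C-contiguous) 3x3 matrix, the default torch.tensor layout
A_rm = torch.arange(9, dtype=torch.float64).reshape(3, 3)
# Its transpose as a view: no data is copied
A_rm_T = A_rm.mT

print(A_rm.stride())    # (3, 1): row-major
print(A_rm_T.stride())  # (1, 3): A^T with column-major strides over the same buffer
print(A_rm.data_ptr() == A_rm_T.data_ptr())  # True: same memory, no copy
```

So handing the untouched buffer to a column-major routine means factorising A^T, and the transpose is then undone inside getrs via trans='T'.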

Note that having less well-maintained paths in some of these libraries is not uncommon. Over the years we have found many bugs in different libraries. For example, we discovered that the trans='T' path of getrs in MAGMA was buggy, and I had to implement this option in terms of the trans='N' option in #77634. See in particular:

// Computes X = U^{-1}L^{-1}P^T B via triangular solves
// Helps mitigating the bugs in magma
auto lu_solve_triangular = [n](const Tensor& LU, const Tensor& pivots, const Tensor& B, const TransposeType trans) {

@lezcano
Collaborator
lezcano commented Dec 9, 2022

You can also see the bounds in the getrs docs for the guarantees given by the CPU backend. Given the condition number of this matrix, I would not be surprised if the solutions provided are indeed within the range specified there.

Now, if you want to recover the previous behaviour, you can bypass the optimisation added in that PR altogether by making the matrix column-major:

A = A.mT.contiguous().mT

You can see that by doing this you get the same results as with PyTorch 1.12. As in PyTorch 1.12 (before that speed optimisation was implemented), you have to copy the data once.
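The workaround works purely through memory layout. A minimal sketch on a random, well-conditioned matrix (not the matrix from the issue) that checks the layout and the values is below; that `torch.linalg.solve` then takes the 1.12-style trans='N' path is taken from the maintainer's comment, not verified by this snippet.

```python
import torch

torch.manual_seed(0)
n = 4
A_w = torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
b_w = torch.randn(n, dtype=torch.float64)

# The workaround: copy A's data once into column-major (Fortran) order.
# .mT.contiguous() materialises A^T row-major, which is A column-major;
# the final .mT views it as A again without another copy.
A_col = A_w.mT.contiguous().mT

print(A_w.stride(), A_col.stride())  # (4, 1) (1, 4)
print(torch.equal(A_w, A_col))       # True: same matrix, different layout

x_w = torch.linalg.solve(A_col, b_w)
print(torch.allclose(A_w @ x_w, b_w))  # True for this well-conditioned system
```

The cost is one extra n×n copy per solve, which is exactly what the 1.13 optimisation was designed to avoid.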

@samdow added the triaged and module: linear algebra labels Dec 12, 2022
@astroboylrx
Author

Thank you very much for the explanations and the workaround👍! My apologies for the late response.

Lesson learned: there are various codepaths maintained at different levels (fascinating indeed 😂).

@lezcano
Collaborator
lezcano commented Dec 21, 2022

Closing this. Feel free to reopen if you have further issues.

@lezcano lezcano closed this as completed Dec 21, 2022