From 028a04fbc284b17acc3d8905f14999c21f0f7139 Mon Sep 17 00:00:00 2001
From: lezcano
Date: Wed, 20 Oct 2021 11:02:07 +0000
Subject: [PATCH 1/5] Add linalg.lu_factor

This PR exposes `torch.lu` as `torch.linalg.lu_factor` and
`torch.linalg.lu_factor_ex`.

This PR also adds support for matrices with zero elements, both in the
matrix dimensions and in the batch dimensions. Note that in this case the
function simply returns empty tensors of the correct size.

We add a test and an OpInfo for the new function.

This PR also adds documentation for this new function in line with the
documentation in the rest of `torch.linalg`.

[ghstack-poisoned]
---
 aten/src/ATen/native/BatchLinearAlgebra.cpp   | 120 +++++++++++++---
 aten/src/ATen/native/BatchLinearAlgebra.h     |   4 +-
 .../ATen/native/BatchLinearAlgebraKernel.cpp  |  16 +--
 aten/src/ATen/native/LinearAlgebraUtils.h     |  14 ++
 .../ATen/native/cuda/BatchLinearAlgebra.cpp   |  54 ++++----
 .../native/cuda/BatchLinearAlgebraLib.cpp     |  58 +++-----
 .../ATen/native/cuda/BatchLinearAlgebraLib.h  |   2 +-
 aten/src/ATen/native/native_functions.yaml    |  21 +++
 docs/source/linalg.rst                        |   2 +
 test/test_linalg.py                           |  61 ++++++++
 tools/autograd/derivatives.yaml               |   7 +-
 torch/csrc/api/include/torch/linalg.h         |  19 +++
 torch/csrc/autograd/FunctionsManual.cpp       |  28 ++--
 torch/csrc/autograd/FunctionsManual.h         |   4 +-
 torch/linalg/__init__.py                      | 130 +++++++++++++++++-
 torch/overrides.py                            |   2 +
 .../_internal/common_methods_invocations.py  |  12 ++
 torch/testing/_internal/common_utils.py       |   2 +
 18 files changed, 433 insertions(+), 123 deletions(-)

diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index e3362f0d6b9f2..a80d7132c5a2e 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -889,7 +889,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
     result = result.unsqueeze_(-1);
   }
 
-  // lu_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
+  // lu_factor_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
   result.copy_(other_broadcasted);
 
   auto input_working_copy = cloneBatchedColumnMajor(input_broadcasted);
@@ -906,7 +906,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
   auto pivots_shape = IntArrayRef(input_broadcasted.sizes().data(), input_broadcasted.dim() - 2).vec(); // input_broadcasted.shape[:-2]
   pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
   Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
-  lu_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
+  lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
 
   // solve the linear system using the LU factorization
   lu_solve_stub(input.device().type(), result, input_working_copy, pivots);
@@ -1554,30 +1554,110 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) {
   return result;
 }
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-DEFINE_DISPATCH(lu_stub);
+DEFINE_DISPATCH(lu_factor_stub);
 
-// TODO: remove check_errors argument
-// https://github.com/pytorch/pytorch/issues/64014
-std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots) {
-  TORCH_CHECK(self.dim() >= 2,
-              "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
-              " instead");
instead"); - auto m = self.size(-2); - auto n = self.size(-1); - auto req_size = self.sizes().vec(); +std::tuple linalg_lu_factor_ex_out(const Tensor& A, + bool pivot, + bool check_errors, + Tensor& LU, + Tensor& pivots, + Tensor& info) { + TORCH_CHECK(A.dim() >= 2, + "expected tensor with 2 or more dimensions, got size: ", A.sizes(), " instead"); + auto req_size = A.sizes().vec(); + // TODO reimplementation of resize_output with format F-contiguous + // We should make this a standalone function + if (resize_output_check(LU, req_size)) { + // Transpose size + std::iter_swap(req_size.end() - 1, req_size.end() - 2); + LU.resize_(req_size, MemoryFormat::Contiguous); + std::iter_swap(req_size.end() - 1, req_size.end() - 2); + LU.transpose_(-2, -1); // make 'LU' have Fortran contiguous memory layLU + } + const auto m = req_size.cend()[-2]; + const auto n = req_size.cend()[-1]; req_size.pop_back(); req_size.back() = std::min(m, n); - auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt)); + at::native::resize_output(pivots, req_size); req_size.pop_back(); - auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt)); + at::native::resize_output(info, req_size); + + const auto LU_f_contig = LU.transpose(-2, -1).is_contiguous() ; + + if (LU_f_contig && !LU.is_same(A)) { + LU.copy_(A); + } + const auto LU_ = borrow_else_clone(LU_f_contig, LU, A, /*C-contig*/false); + + const auto pivots_contig = pivots.is_contiguous(); + const auto pivots_ = borrow_else_clone(pivots_contig, pivots, pivots, /*C-contig*/true); + + const auto info_contig = info.is_contiguous(); + const auto info_ = borrow_else_clone(info_contig, info, info, /*C-contig*/true); + + + lu_factor_stub(A.device().type(), *LU_, *pivots_, *info_, pivot); + + if (!LU_f_contig) { + LU.copy_(*LU_); + } + if (!pivots_contig) { + pivots.copy_(*pivots_); + } + if (!info_contig) { + info.copy_(*info_); + } + + if (check_errors) { + if (A.dim() > 2) { + batchCheckErrors(info, "torch.linalg.lu_factor_ex"); + } else { + singleCheckErrors(info.item(), "torch.linalg.lu_factor_ex"); + } + } + + return {LU, pivots, info}; +} + +std::tuple linalg_lu_factor_ex(const Tensor& A, bool pivot, bool check_errors) { + auto LU = at::empty({0}, A.options()); + auto pivots = at::empty({0}, A.options().dtype(kInt)); + auto info = at::empty({0}, A.options().dtype(kInt)); + at::native::linalg_lu_factor_ex_out(A, pivot, check_errors, LU, pivots, info); + return {std::move(LU), std::move(pivots), std::move(info)}; +} + +std::tuple linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor & LU, Tensor & pivots) { + auto info = at::empty({0}, A.options().dtype(kInt)); + // We pass check_errors as we want to use lu_factor rather than lu_factor_ex in the errors + at::native::linalg_lu_factor_ex_out(A, pivot, /*chech_errors=*/false, LU, pivots, info); + if (A.dim() > 2) { + batchCheckErrors(info, "torch.linalg.lu_factor"); + } else { + singleCheckErrors(info.item(), "torch.linalg.lu_factor"); + } + + return {LU, pivots}; +} + +std::tuple linalg_lu_factor(const Tensor& A, bool pivot) { + Tensor LU, pivots, info; + std::tie(LU, pivots, info) = at::linalg_lu_factor_ex(A, pivot, /*check_errors=*/false); + + if (A.dim() > 2) { + batchCheckErrors(info, "torch.linalg.lu_factor"); + } else { + singleCheckErrors(info.item(), "torch.linalg.lu_factor"); + } + + return {std::move(LU), std::move(pivots)}; +} - // lu_stub (apply_lu) requires batched column major (Fortran-contiguous) tensors - // 'lu' tensor is modified in-place and must be a copy of 'self' - Tensor lu = 
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h
index dba3a415cdc9e..8edcbc72b5f48 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.h
+++ b/aten/src/ATen/native/BatchLinearAlgebra.h
@@ -219,12 +219,12 @@ using triangular_solve_fn = void (*)(
     bool /*unitriangular*/);
 DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
 
-using lu_fn = void (*)(
+using lu_factor_fn = void (*)(
     const Tensor& /*input*/,
     const Tensor& /*pivots*/,
     const Tensor& /*infos*/,
     bool /*compute_pivots*/);
-DECLARE_DISPATCH(lu_fn, lu_stub);
+DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
 
 using lu_solve_fn = void (*)(
     const Tensor& /*b*/,
diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
index a910cf1fd46fc..84800995cc860 100644
--- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -846,14 +846,14 @@ void triangular_solve_kernel(Tensor& A, Tensor& B, bool left, bool upper, Transp
   For further details, please see the LAPACK documentation for GETRF.
 */
 template <typename scalar_t>
-void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
 #if !AT_BUILD_WITH_LAPACK()
   TORCH_CHECK(
       false,
       "Calling torch.lu on a CPU tensor requires compiling ",
       "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
 #else
-  TORCH_CHECK(compute_pivots, "lu without pivoting is not implemented on the CPU");
+  TORCH_CHECK(compute_pivots, "lu_factor without pivoting is not implemented on the CPU");
 
   auto input_data = input.data_ptr<scalar_t>();
   auto pivots_data = pivots.data_ptr<int>();
@@ -875,9 +875,9 @@ void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bo
 }
 
 // This is a type dispatching helper function for 'apply_lu'
-void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+void lu_factor_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_cpu", [&]{
-    apply_lu<scalar_t>(input, pivots, infos, compute_pivots);
+    apply_lu_factor<scalar_t>(input, pivots, infos, compute_pivots);
   });
 }
 
@@ -994,10 +994,10 @@ REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
 REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
 REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
 
-REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel);
-REGISTER_AVX512_DISPATCH(lu_stub, &lu_kernel);
-REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel);
-REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel);
+REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel);
+REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel);
+REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel);
+REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel);
 
 REGISTER_ARCH_DISPATCH(lu_solve_trans_stub, DEFAULT, &lu_solve_trans_kernel);
 REGISTER_AVX512_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_kernel);
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index c495fc8307565..e68035eb35f10 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -50,6 +50,15 @@ static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
   return result;
 }
 
+/*
+ * contig chooses between C-contig (true) and F-contig (false)
+ */
+static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
+  return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
+              : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
+                                                      : cloneBatchedColumnMajor(clone));
+}
+
 /*
  * This method is designed to be a faster alternative to
  * `cloneBatchedColumnMajor` with some additional features,
@@ -265,6 +274,11 @@ static inline void singleCheckErrors(int64_t info, const char* name, int64_t bat
   } else if (strstr(name, "lstsq")) {
     TORCH_CHECK(false, name, batch_string,
         ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
+  } else if (strstr(name, "lu_factor")) {
+    TORCH_CHECK(false, name, batch_string,
+        ": U[", info, ",", info, "] is zero and using it on lu_solve would result in a division by zero. "
+        "If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or "
+        "linalg.lu_factor_ex(A, pivot)");
   } else {
     TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, ".");
   }
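The new `lu_factor` branch of `singleCheckErrors` above is what surfaces a LAPACK `info > 0` as a Python-visible exception. A hedged sketch of the intended contrast between the checked and unchecked entry points (the singular input is illustrative):

```python
import torch

A = torch.zeros(3, 3)  # singular: U has a zero on its diagonal

# lu_factor_ex does not raise by default; failures are reported through `info`.
LU, pivots, info = torch.linalg.lu_factor_ex(A)
print(info)  # a positive entry marks the first zero pivot (1-based, LAPACK convention)

# lu_factor always checks and should raise the message added above.
try:
    torch.linalg.lu_factor(A)
except RuntimeError as e:
    print(e)  # "... U[i,i] is zero and using it on lu_solve would result in ..."
```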
" + "If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or " + "linalg.lu_factor_ex(A, pivot)"); } else { TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, "."); } diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp index a1931dd3ec9c8..5388595eee31a 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp @@ -1807,7 +1807,7 @@ REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); For further details, please see the MAGMA documentation for magma_dgetrf_gpu. */ template -static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { +static void apply_lu_factor_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { #if !AT_MAGMA_ENABLED() TORCH_CHECK( false, @@ -1836,6 +1836,7 @@ static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, con int* infos_working_ptr = &infos_data[i]; magmaLu(m, n, input_working_ptr, leading_dimension, pivots_working_ptr, infos_working_ptr); } + // Why can we safely do non_blocking? pivots.copy_(pivots_cpu, /*non_blocking=*/true); } else { for (decltype(batch_size) i = 0; i < batch_size; i++) { @@ -1843,11 +1844,6 @@ static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, con int* infos_working_ptr = &infos_data[i]; magmaLuNoPiv(m, n, input_working_ptr, leading_dimension, infos_working_ptr); } - - // fill the pivots tensor with indices using 1-based (Fortran) indexing - auto k = std::min(m, n); - Tensor pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt)).expand_as(pivots); - pivots.copy_(pivots_tmp); } infos.copy_(infos_cpu, /*non_blocking=*/true); #endif @@ -1868,7 +1864,7 @@ static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, con For further details, please see the MAGMA documentation for magma_dgetrf_batched. 
 */
 template <typename scalar_t>
-static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
 #if !AT_MAGMA_ENABLED()
   TORCH_CHECK(
       false,
@@ -1880,13 +1876,6 @@ static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, co
   auto input_matrix_stride = matrixStride(input);
   magma_int_t batch_size = magma_int_cast(batchCount(input), "batchCount");
 
-  // magmaLuBatched doesn't work with zero batch dimensions
-  // it gives CUDA error: invalid configuration argument
-  if (batch_size == 0) {
-    infos.fill_(0);
-    return;
-  }
-
   magma_int_t m = magma_int_cast(input.size(-2), "m");
   magma_int_t n = magma_int_cast(input.size(-1), "n");
   auto leading_dimension = std::max<magma_int_t>(1, m);
@@ -1916,11 +1905,6 @@ static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, co
     magmaLuBatched<scalar_t>(m, n, input_array, leading_dimension, pivots_array, infos_data, batch_size, magma_queue);
   } else {
     magmaLuNoPivBatched<scalar_t>(m, n, input_array, leading_dimension, infos_data, batch_size, magma_queue);
-
-    // fill the pivots tensor with indices using 1-based (Fortran) indexing
-    auto k = std::min(m, n);
-    Tensor pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt)).expand_as(pivots);
-    pivots.copy_(pivots_tmp);
   }
 
   // block CPU until all operations on the queue are finished
@@ -1929,38 +1913,50 @@ static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, co
 #endif
 }
 
-static void lu_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+static void lu_factor_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_magma_looped", [&]{
-    apply_lu_looped_magma<scalar_t>(input, pivots, infos, compute_pivots);
+    apply_lu_factor_looped_magma<scalar_t>(input, pivots, infos, compute_pivots);
   });
 }
 
-static void lu_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+static void lu_factor_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_magma_batched", [&]{
-    apply_lu_batched_magma<scalar_t>(input, pivots, infos, compute_pivots);
+    apply_lu_factor_batched_magma<scalar_t>(input, pivots, infos, compute_pivots);
  });
 }
 
-static void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
-  int64_t batch_size = batchCount(input);
+static void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+  auto batch_size = batchCount(input);
+  // MAGMA does not work with batch_size == 0.
+  // CuSolver does not work when the matrices have no elements
+  if (input.numel() == 0) {
+    return;
+  }
 #ifdef USE_CUSOLVER
   // Use a heuristic to determine that cusolver is faster than MAGMA for the following sizes.
   auto m = input.size(-2);
   // exclude complex128 since nan_to_num_ does not work with it.
+  // See https://github.com/pytorch/pytorch/issues/59247 for more info
   if ((batch_size == 1 || (batch_size <= 8 && m <= 16) || !use_magma_ ) && !input.is_complex()) {
-    lu_looped_cusolver(input, pivots, infos, compute_pivots);
+    lu_factor_looped_cusolver(input, pivots, infos, compute_pivots);
   }
 #else
   if (batch_size == 1) {
-    lu_looped_magma(input, pivots, infos, compute_pivots);
+    lu_factor_looped_magma(input, pivots, infos, compute_pivots);
   }
 #endif // USE_CUSOLVER
   else {
-    lu_batched_magma(input, pivots, infos, compute_pivots);
+    lu_factor_batched_magma(input, pivots, infos, compute_pivots);
+  }
+  // We return the trivial permutation of pivots starting with 1 (FORTRAN indexing)
+  if (!compute_pivots) {
+    auto k = std::min(input.size(-2), input.size(-1));
+    auto pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt));
+    pivots.copy_(pivots_tmp);
   }
 }
 
-REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu);
+REGISTER_CUDA_DISPATCH(lu_factor_stub, &apply_lu_factor);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp
index a41830387c043..2cd4eeafc8cee 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp
@@ -1250,50 +1250,34 @@ void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors,
 // The 'apply_' word is used for templated by dtype functions that call an API routine
 // underneath. Since the cusolver API has a slightly different structure we do not prepend
 // apply_ to this function.
-void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
-  // Fill the pivots tensor with indices using 1-based (Fortran) indexing. This
-  // is needed for maintaining the same results with MAGMA.
-  auto k = std::min(self.size(-2), self.size(-1));
-  Tensor pivots_tmp = at::arange(1, k + 1, self.options().dtype(at::kInt)).expand_as(pivots);
-  pivots.copy_(pivots_tmp);
-
+void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots) {
   AT_DISPATCH_FLOATING_TYPES(
     self.scalar_type(),
-    "lu_cusolver",
+    "lu_factor_cusolver",
     [&self, &pivots, &infos, &get_pivots]() {
-    int m = cuda_int_cast(self.size(-2), "m");
-    int n = cuda_int_cast(self.size(-1), "n");
-    int lda = std::max(1, m);
-    int64_t self_stride = matrixStride(self);
-    int64_t batch_size = batchCount(self);
-    scalar_t* self_data = self.data_ptr<scalar_t>();
-    int* infos_data = infos.data_ptr<int>();
-
-    auto handle = at::cuda::getCurrentCUDASolverDnHandle();
+      const auto m = cuda_int_cast(self.size(-2), "m");
+      const auto n = cuda_int_cast(self.size(-1), "n");
+      const auto lda = std::max(1, m);
+      const auto self_stride = matrixStride(self);
+      const auto batch_size = batchCount(self);
+      const auto self_data = self.data_ptr<scalar_t>();
+      const auto infos_data = infos.data_ptr<int>();
+
+      const auto pivots_data = get_pivots ? pivots.data_ptr<int>() : nullptr;
+      const auto pivots_stride = get_pivots ? pivots.size(-1) : 0;
+
+      const auto handle = at::cuda::getCurrentCUDASolverDnHandle();
       for (auto batch = decltype(batch_size){0}; batch < batch_size; ++batch) {
-      if (get_pivots) {
-        auto pivots_data = pivots.data_ptr<int>();
-        auto pivots_stride = pivots.size(-1);
-        at::cuda::solver::getrf<scalar_t>(
-          handle, m, n,
-          self_data + batch * self_stride,
-          lda,
-          pivots_data + batch * pivots_stride,
-          infos_data + batch
-        );
-      }
-      else {
-        at::cuda::solver::getrf<scalar_t>(
-          handle, m, n,
-          self_data + batch * self_stride,
-          lda,
-          nullptr,
-          infos_data + batch
-        );
-      }
+        at::cuda::solver::getrf<scalar_t>(
+          handle, m, n,
+          self_data + batch * self_stride,
+          lda,
+          get_pivots ? pivots_data + batch * pivots_stride : nullptr,
+          infos_data + batch
+        );
       }
   });
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h
index 2c48a0d5d6d0f..8ffb5ff4e9023 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h
@@ -61,7 +61,7 @@ Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);
 void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
 
 void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose);
-void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);
+void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);
 
 #endif // USE_CUSOLVER
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 64817f9eb3b28..e3e42d24a4a23 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -10360,6 +10360,27 @@
   python_module: linalg
   variants: function
 
+# linalg.lu_factor
+- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
+  python_module: linalg
+  variants: function
+
+- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
+  python_module: linalg
+  variants: function
+
+- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_lu_factor_ex
+
+- func: linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info)
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_lu_factor_ex_out
+
 - func: linalg_det(Tensor self) -> Tensor
   python_module: linalg
   variants: function
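The `.out` schemas registered above take preallocated outputs; as implemented in `linalg_lu_factor_ex_out`, wrongly sized outputs are resized rather than rejected. A sketch of the calling convention this enables (sizes are illustrative):

```python
import torch

A = torch.randn(2, 3, 3)
LU = torch.empty(0)
pivots = torch.empty(0, dtype=torch.int32)
info = torch.empty(0, dtype=torch.int32)

# The outputs are resized to (2, 3, 3), (2, 3) and (2,) respectively.
torch.linalg.lu_factor_ex(A, pivot=True, check_errors=False, out=(LU, pivots, info))
print(LU.shape, pivots.shape, info.shape)
```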
diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst
index 0c4c96f6fe50f..cba8587d88954 100644
--- a/docs/source/linalg.rst
+++ b/docs/source/linalg.rst
@@ -33,6 +33,7 @@ Decompositions
 
     cholesky
     qr
+    lu_factor
     eig
     eigvals
     eigh
@@ -99,3 +100,4 @@ Experimental Functions
 
     cholesky_ex
     inv_ex
+    lu_factor_ex
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 826b2cf1fc69f..d4e58bdfbcbfd 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -5531,6 +5531,67 @@ def test_householder_product_errors_and_warnings(self, device):
         with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
             torch.linalg.householder_product(reflectors, tau)
 
+    @precisionOverride({torch.complex64: 5e-6})
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(*floating_and_complex_types())
+    def test_linalg_lu_factor(self, device, dtype):
+        from torch.testing._internal.common_utils import random_matrix
+
+        def run_test(A, pivot, singular):
+            k = min(A.shape[-2:])
+            batch = A.shape[:-2]
+            if singular:
+                if pivot:
+                    # We discard the errors, as the factorization always succeeds
+                    LU, pivots, _ = torch.linalg.lu_factor_ex(A, pivot=pivot)
+                else:
+                    # It may or may not throw, as the LU decomposition without pivoting
+                    # may still succeed for singular matrices
+                    try:
+                        LU, pivots = torch.linalg.lu_factor(A, pivot=pivot)
+                    except RuntimeError:
+                        return
+            else:
+                LU, pivots = torch.linalg.lu_factor(A, pivot=pivot)
+
+            self.assertEqual(LU.size(), A.shape)
+            self.assertEqual(pivots.size(), batch + (k,))
+
+            if not pivot:
+                self.assertEqual(pivots, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(batch + (k, )))
+
+            P, L, U = torch.lu_unpack(LU, pivots)
+
+            self.assertEqual(P @ L @ U, A)
+
+        sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
+        batches = ((0,), (2,), (3,), (1, 0), (3, 5))
+        # The no-pivoting variant is only implemented on CUDA
+        pivots = (True, False) if self.device_type == 'cuda' else (True,)
+        for ms, batch, pivot, singular in itertools.product(sizes, batches, pivots, (True, False)):
+            m, n = ms
+            A = random_matrix(m, n, *batch, singular=singular, dtype=dtype, device=device)
+            # Run empty inputs just once; the singular flag is irrelevant for them
+            if A.numel() == 0 and not singular:
+                continue
+            run_test(A, pivot, singular)
+
+        # Reproducer of a magma bug,
+        # see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on
+        if dtype == torch.double and singular:
+            A = torch.ones(batch + ms, dtype=dtype, device=device)
+            run_test(A, pivot, singular)
+
+        # Info should be positive for rank deficient matrices
+        A = torch.ones(5, 3, 3, device=device)
+        self.assertTrue((torch.linalg.lu_factor_ex(A, pivot=True).info >= 0).all())
+
+        if self.device_type == 'cpu':
+            # Error checking, no pivoting variant on CPU
+            with self.assertRaisesRegex(RuntimeError, 'lu_factor without pivoting is not implemented on the CPU'):
+                torch.lu(torch.empty(1, 2, 2), pivot=False)
+
     @precisionOverride({torch.complex64: 5e-6})
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
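The heart of the test above is the round trip through `torch.lu_unpack`; distilled to a standalone snippet it reads:

```python
import torch

A = torch.randn(3, 4, 4, dtype=torch.cdouble)
LU, pivots = torch.linalg.lu_factor(A)

# Expand the compact representation into explicit P, L, U factors
# and check the reconstruction, exactly as run_test does.
P, L, U = torch.lu_unpack(LU, pivots)
assert torch.allclose(P @ L @ U, A)
```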
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 6d920dd317208..7114dd4793a85 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -882,9 +882,10 @@
   self: zeros_like(self)
   other: zeros_like(other)
 
-- name: _lu_with_info(Tensor self, bool pivot=True) -> (Tensor LU, Tensor pivots, Tensor info)
-  self: _lu_with_info_backward(grad, self, LU, pivots)
-  LU: _lu_with_info_jvp(self_t, LU, pivots)
+- name: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
+  A: lu_factor_ex_backward(grad, A, LU, pivots)
+  LU: lu_factor_ex_jvp(A_t, LU, pivots)
+  output_differentiability: [True, False, False]
 
 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h
index 508a7fe443a8a..20ee88f8e4345 100644
--- a/torch/csrc/api/include/torch/linalg.h
+++ b/torch/csrc/api/include/torch/linalg.h
@@ -68,6 +68,14 @@ inline Tensor& householder_product_out(Tensor& result, const Tensor& input, cons
   return torch::linalg_householder_product_out(result, input, tau);
 }
 
+inline std::tuple<Tensor, Tensor> lu_factor(const Tensor& self, const bool pivot) {
+  return torch::linalg_lu_factor(self, pivot);
+}
+
+inline std::tuple<Tensor&, Tensor&> lu_factor_out(Tensor& LU, Tensor& pivots, const Tensor& self, const bool pivot) {
+  return torch::linalg_lu_factor_out(LU, pivots, self, pivot);
+}
+
 inline std::tuple<Tensor, Tensor, Tensor, Tensor> lstsq(const Tensor& self, const Tensor& b, c10::optional<double> cond, c10::optional<c10::string_view> driver) {
   return torch::linalg_lstsq(self, b, cond, driver);
 }
@@ -333,6 +341,17 @@ inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, c10::string_v
   return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
 }
 
+/// Computes the pivoted LU factorization
+///
+/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.lu_factor
+inline std::tuple<Tensor, Tensor> lu_factor(const Tensor& input, const bool pivot=true) {
+  return detail::lu_factor(input, pivot);
+}
+
+inline std::tuple<Tensor&, Tensor&> lu_factor_out(Tensor& LU, Tensor& pivots, const Tensor& self, const bool pivot=true) {
+  return detail::lu_factor_out(LU, pivots, self, pivot);
+}
+
 inline Tensor norm(const Tensor& self, const optional<Scalar>& opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
   return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
 }
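With the `derivatives.yaml` entry above, `linalg_lu_factor_ex` gets both a backward formula and a forward-mode JVP, with only the `LU` output marked differentiable. A hedged sketch of how this is typically exercised, mirroring what the OpInfo-driven autograd tests do (`take_lu` is an illustrative helper):

```python
import torch
from torch.autograd import gradcheck

A = torch.randn(3, 3, dtype=torch.double, requires_grad=True)

# pivots and info are integer outputs, so differentiate through LU only.
def take_lu(A):
    return torch.linalg.lu_factor_ex(A)[0]

assert gradcheck(take_lu, (A,))                         # backward formula
assert gradcheck(take_lu, (A,), check_forward_ad=True)  # JVP formula
```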
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index f545c8c907573..bca5c6916c24c 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -4040,7 +4040,7 @@ Tensor lu_solve_jvp(
   // The identity permutation pivots are 1-based because of the Fortran-like LAPACK interfaces.
   // More details on the permutation matrix canceling note:
   // as part of forward AD we need to compute A^{-1} dA.
-  // Since A = P L U and P is not differentiable, we get
+  // Since A = P L U and P is locally constant for full-rank matrices, we get
   // dA = P d(L U), A^{-1} = (L U)^{-1} P^T, so
   // A^{-1} dA = (L U)^{-1} d(L U), which is lu_solve with
   // the pivots set to the identity permutation
@@ -4306,7 +4306,7 @@ Tensor plu_backward_base(
   return self_grad;
 }
 
-Tensor _lu_with_info_backward(
+Tensor lu_factor_ex_backward(
   const Tensor& grad,
   const Tensor& self,
   const Tensor& LU,
@@ -4320,8 +4320,8 @@ Tensor _lu_with_info_backward(
   return plu_backward_base({/*L_grad=*/grad, /*U_grad=*/grad}, self, P, L, U);
 }
 
-Tensor _lu_with_info_jvp(
-  const Tensor& dX,
+Tensor lu_factor_ex_jvp(
+  const Tensor& dA,
   const Tensor& LU,
   const Tensor& pivs
 ) {
@@ -4335,19 +4335,19 @@ Tensor _lu_with_info_jvp(
   auto n = LU.size(-1);
   auto k = std::min(m, n);
 
-  auto pdX = P.mT().matmul(dX);
+  auto pdA = P.mT().matmul(dA);
 
   // similar to the backward implementation, we also consider block structures such as:
   // for a matrix A of size m x n we decompose it as
   // A = (A1 | A2) with A1 of size m x m if m <= n and
   // A = (A1^T | A2^T)^T with A1 of size n x n if m > n.
-  auto pdX1 = pdX.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto pdA1 = pdA.narrow(-2, 0, k).narrow(-1, 0, k);
   auto L1 = L.narrow(-2, 0, k).narrow(-1, 0, k);
   auto U1 = U.narrow(-2, 0, k).narrow(-1, 0, k);
 
-  // dK = L1^{-1} pdX1
+  // dK = L1^{-1} pdA1
   auto dK = std::get<0>(at::triangular_solve(
-    pdX1,
+    pdA1,
     L1,
     /*upper=*/false,
     /*transpose=*/false,
@@ -4377,11 +4377,11 @@ Tensor _lu_with_info_jvp(
 
   if (m < n) {
     // we only need to update dU2 defined as
-    // dU2 := L1^{-1} (pdX2 - dL1 U2)
-    auto pdX2 = pdX.narrow(-1, k, n - k);
+    // dU2 := L1^{-1} (pdA2 - dL1 U2)
+    auto pdA2 = pdA.narrow(-1, k, n - k);
     auto U2 = U.narrow(-1, k, n - k);
     dLU.narrow(-1, k, n - k).copy_(std::get<0>(at::triangular_solve(
-      pdX2 - dL1.matmul(U2),
+      pdA2 - dL1.matmul(U2),
       L1,
       /*upper=*/false,
       /*transpose=*/false,
@@ -4390,11 +4390,11 @@ Tensor _lu_with_info_jvp(
   }
   else {
     // we only need to update dL2 defined as
-    // dL2 := (pdX2 - L2 dU1) U1^{-1}
-    auto pdX2 = pdX.narrow(-2, k, m - k);
+    // dL2 := (pdA2 - L2 dU1) U1^{-1}
+    auto pdA2 = pdA.narrow(-2, k, m - k);
     auto L2 = L.narrow(-2, k, m - k);
     dLU.narrow(-2, k, m - k).copy_(std::get<0>(at::triangular_solve(
-      (pdX2 - L2.matmul(dU1)).mT(),
+      (pdA2 - L2.matmul(dU1)).mT(),
       U1,
       /*upper=*/true,
       /*transpose=*/true
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 55c601aeacf5c..a7b1eb3a95044 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -321,13 +321,13 @@ Tensor lu_backward_base(
   const Tensor& L,
   const Tensor& U
 );
-Tensor _lu_with_info_backward(
+Tensor lu_factor_ex_backward(
   const Tensor& grad,
   const Tensor& self,
   const Tensor& LU,
   const Tensor& pivs
 );
-Tensor _lu_with_info_jvp(
+Tensor lu_factor_ex_jvp(
   const Tensor& dX,
   const Tensor& LU,
   const Tensor& pivs
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index cc05dd6760db3..200665127b614 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -7,7 +7,10 @@
 Tensor = torch.Tensor
 
 common_notes = {
-    "sync_note": """When inputs are on a CUDA device, this function synchronizes that device with the CPU."""
+    "experimental_warning": """This function is "experimental" and it may change in a future PyTorch release.""",
+    "sync_note": "When inputs are on a CUDA device, this function synchronizes that device with the CPU.",
+    "sync_note_ex": r"When the inputs are on a CUDA device, this function synchronizes only when :attr:`check_errors`\ `= True`.",
+    "sync_note_has_ex": "When inputs are on a CUDA device, this function synchronizes that device with the CPU. For a version of this function that does not synchronize, see :func:`{}`."
 }
@@ -111,9 +114,11 @@
 ``info`` filled with zeros indicates that the decomposition was successful. If
 ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown.
 
-.. note:: If :attr:`A` is on a CUDA device, this function may synchronize that device with the CPU.
+""" + fr"""
+.. note:: {common_notes["sync_note_ex"]}
 
-.. warning:: This function is "experimental" and it may change in a future PyTorch release.
+.. warning:: {common_notes["experimental_warning"]}
+""" + r"""
 
 .. seealso::
         :func:`torch.linalg.cholesky` is a NumPy compatible variant that always checks for errors.
@@ -240,11 +245,11 @@
 Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
 the output has the same batch dimensions.
 
-.. note::
-    If :attr:`A` is on a CUDA device then this function may synchronize
-    that device with the CPU.
+""" + fr"""
+.. note:: {common_notes["sync_note_ex"]}
 
-.. warning:: This function is "experimental" and it may change in a future PyTorch release.
+.. warning:: {common_notes["experimental_warning"]}
+""" + r"""
 
 .. seealso::
 
@@ -1883,6 +1888,117 @@
         https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem
 """)
 
+lu_factor = _add_docstr(_linalg.linalg_lu_factor, r"""
+linalg.lu_factor(A, *, pivot=True, out=None) -> (Tensor, Tensor)
+
+Computes a compact representation of the LU factorization with partial pivoting of a matrix.
+
+This function computes a compact representation of the decomposition given by :func:`torch.linalg.lu`.
+If the matrix is square, this representation may be used in :func:`torch.linalg.lu_solve`
+to solve systems of linear equations that share the matrix :attr:`A`.
+
+The returned permutation matrix is represented by a 1-indexed vector. `pivots[i] == j` represents
+that in the `i`-th step of the algorithm, the `i`-th row was permuted with the `j-1`-th row.
+
+On CUDA, one may use :attr:`pivot`\ `= False`. In this case, this function returns the LU decomposition without pivoting if it exists.
+
+Supports inputs of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if the inputs are batches of matrices then
+the output has the same batch dimensions.
+
+""" + fr"""
+.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.lu_factor_ex")}
+""" + r"""
+
+.. seealso::
+
+        :func:`torch.linalg.lu_solve` solves a system of linear equations given the output of this
+        function, provided the input matrix was square and invertible.
+
+        :func:`torch.linalg.lu` computes the LU decomposition with partial pivoting of a possibly
+        non-square matrix.
+
+        :func:`torch.linalg.solve` solves a system of linear equations. It can be seen as a composition
+        of this function and :func:`~lu_solve`.
+
+Args:
+    A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
+
+Keyword args:
+    pivot (bool, optional): [Only on CUDA] Whether to compute the LU decomposition with partial pivoting, or the regular LU decomposition. Default: `True`.
+    out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`.
+
+Returns:
+    A named tuple `(LU, pivots)`.
+
+Raises:
+    RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A`
+                  is not invertible.
+
+Examples::
+
+    >>> A = torch.randn(2, 3, 3)
+    >>> B1 = torch.randn(2, 3, 4)
+    >>> B2 = torch.randn(2, 3, 7)
+    >>> A_factor = torch.linalg.lu_factor(A)
+    >>> X1 = torch.linalg.lu_solve(A_factor, B1)
+    >>> X2 = torch.linalg.lu_solve(A_factor, B2)
+    >>> torch.allclose(A @ X1, B1)
+    True
+    >>> torch.allclose(A @ X2, B2)
+    True
+
+.. _invertible:
+    https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem
+""")
+
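To make the `pivots[i] == j` convention in the docstring above concrete, here is a hedged sketch (single matrix, no batching; `pivots_to_P` is an illustrative helper, not part of this PR) that decodes the 1-indexed pivot vector into the permutation matrix that `torch.lu_unpack` returns:

```python
import torch

def pivots_to_P(pivots, m):
    # Replay the recorded row swaps: step i swapped row i with row pivots[i]-1.
    rows = list(range(m))
    for i, j in enumerate(pivots.tolist()):
        rows[i], rows[j - 1] = rows[j - 1], rows[i]
    P = torch.zeros(m, m)
    P[rows, torch.arange(m)] = 1.0  # so that A = P @ L @ U
    return P

A = torch.randn(4, 4)
LU, pivots = torch.linalg.lu_factor(A)
P_ref, L, U = torch.lu_unpack(LU, pivots)
assert torch.equal(pivots_to_P(pivots, 4), P_ref)
```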
+lu_factor_ex = _add_docstr(_linalg.linalg_lu_factor_ex, r"""
+linalg.lu_factor_ex(A, *, pivot=True, check_errors=False, out=None) -> (Tensor, Tensor, Tensor)
+
+This is a version of :func:`~lu_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`.
+
+It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_. This tensor contains integers
+denoting the errors that may have happened during the computation of this function.
+
+""" + fr"""
+.. note:: {common_notes["sync_note_ex"]}
+
+.. warning:: {common_notes["experimental_warning"]}
+""" + r"""
+
+.. seealso::
+        :func:`~lu_factor` is a SciPy compatible variant that always checks for errors.
+
+Args:
+    A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
+
+Keyword args:
+    pivot (bool, optional): [Only on CUDA] Whether to compute the LU decomposition with partial pivoting, or the regular LU decomposition. Default: `True`.
+    check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`.
+    out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`.
+
+Returns:
+    A named tuple `(LU, pivots, info)`.
+
+Examples::
+
+    >>> A = torch.randn(2, 3, 3)
+    >>> B1 = torch.randn(2, 3, 4)
+    >>> B2 = torch.randn(2, 3, 7)
+    >>> LU, pivots, info = torch.linalg.lu_factor_ex(A)
+    >>> info
+    tensor([0, 0, 0], dtype=torch.int32)
+    >>> X1 = torch.linalg.lu_solve((LU, pivots), B1)
+    >>> X2 = torch.linalg.lu_solve((LU, pivots), B2)
+    >>> torch.allclose(A @ X1, B1)
+    True
+    >>> torch.allclose(A @ X2, B2)
+    True
+
+.. _LAPACK's getrf:
+    https://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html
+""")
+
 tensorinv = _add_docstr(_linalg.linalg_tensorinv, r"""
 linalg.tensorinv(A, ind=2, *, out=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 4133ca6eab975..c5f647bb80124 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -609,6 +609,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.masked_scatter: lambda input, mask, source: -1,
     torch.masked_select: lambda input, mask, out=None: -1,
     torch.matmul: lambda input, other, out=None: -1,
+    torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1,
+    torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1,
     torch.linalg.matmul: lambda input, other, out=None: -1,  # alias for torch.matmul
     torch.matrix_power: lambda input, n: -1,
     torch.linalg.matrix_power: lambda input, n, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 42e88401b7bdf..f78d92dd81fd3 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7971,6 +7971,18 @@ def ref_pairwise_distance(input1, input2):
            dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
            supports_autograd=False,
            sample_inputs_func=sample_inputs_comparison_ops),
+    OpInfo('linalg.lu_factor',
+           aten_name='linalg_lu_factor',
+           op=torch.linalg.lu_factor,
+           dtypes=floating_and_complex_types(),
+           supports_inplace_autograd=False,
+           # we use in-place operations, which cannot be avoided.
+           # This causes vmap failures, hence we skip batched gradient checks
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_lu,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
     OpInfo('lu',
            op=torch.lu,
            dtypes=floating_and_complex_types(),
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 1b2bc729de596..855c69d4bcff0 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -2432,6 +2432,8 @@ def random_matrix(rows, columns, *batch_dims, **kwargs):
         return torch.ones(rows, columns, dtype=dtype, device=device)
 
     A = torch.randn(batch_dims + (rows, columns), dtype=dtype, device=device)
+    if A.numel() == 0:
+        return A
     u, _, vh = torch.linalg.svd(A, full_matrices=False)
     k = min(rows, columns)
     s = torch.linspace(1 / (k + 1), 1, k, dtype=dtype, device=device)
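The `random_matrix` early return above short-circuits zero-sized inputs before the SVD-based conditioning below it, presumably so the new empty-matrix cases in `test_linalg_lu_factor` don't have to go through it. A sketch of what those cases look like through the public API:

```python
import torch

A = torch.randn(3, 0, 5)  # empty matrix dimension inside a non-empty batch
LU, pivots = torch.linalg.lu_factor(A)

# Per the PR, empty inputs simply produce correctly sized empty outputs:
# LU matches A's shape and pivots has min(m, n) = 0 trailing entries.
print(LU.shape, pivots.shape)  # torch.Size([3, 0, 5]) torch.Size([3, 0])
```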
From 5dd416f480fb234d1a45cfb05c8207b2d6e643f6 Mon Sep 17 00:00:00 2001
From: lezcano
Date: Wed, 20 Oct 2021 13:56:23 +0000
Subject: [PATCH 2/5] Update on "Add linalg.lu_factor"

This PR exposes `torch.lu` as `torch.linalg.lu_factor` and
`torch.linalg.lu_factor_ex`.

This PR also adds support for matrices with zero elements, both in the
matrix dimensions and in the batch dimensions. Note that in this case the
function simply returns empty tensors of the correct size.

We add a test and an OpInfo for the new function.

This PR also adds documentation for this new function in line with the
documentation in the rest of `torch.linalg`.

[ghstack-poisoned]
---
 aten/src/ATen/native/BatchLinearAlgebra.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index a80d7132c5a2e..c47bf814e9df1 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1626,7 +1626,7 @@ std::tuple<Tensor, Tensor, Tensor> linalg_lu_factor_ex(const Tensor& A, bool piv
   auto pivots = at::empty({0}, A.options().dtype(kInt));
   auto info = at::empty({0}, A.options().dtype(kInt));
   at::native::linalg_lu_factor_ex_out(A, pivot, check_errors, LU, pivots, info);
-  return {std::move(LU), std::move(pivots), std::move(info)};
+  return std::make_tuple(std::move(LU), std::move(pivots), std::move(info));
 }
 
 std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor & LU, Tensor & pivots) {
@@ -1652,7 +1652,7 @@ std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
     singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
   }
 
-  return {std::move(LU), std::move(pivots)};
+  return std::make_tuple(std::move(LU), std::move(pivots));
 }
 
 // TODO Deprecate this function in favour of linalg_lu_factor_ex
From 15e815eddb3fefa96b1e7fffe7fd9673ad1b26f7 Mon Sep 17 00:00:00 2001
From: lezcano
Date: Wed, 20 Oct 2021 15:12:53 +0000
Subject: [PATCH 3/5] Update on "Add linalg.lu_factor"

This PR exposes `torch.lu` as `torch.linalg.lu_factor` and
`torch.linalg.lu_factor_ex`.

This PR also adds support for matrices with zero elements, both in the
matrix dimensions and in the batch dimensions. Note that in this case the
function simply returns empty tensors of the correct size.

We add a test and an OpInfo for the new function.

This PR also adds documentation for this new function in line with the
documentation in the rest of `torch.linalg`.

[ghstack-poisoned]
---
 aten/src/ATen/native/BatchLinearAlgebra.cpp |  4 ++--
 torch/linalg/__init__.py                    | 12 ++++++++----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index c47bf814e9df1..41fcf659ca781 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1618,7 +1618,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> linalg_lu_factor_ex_out(const Tensor& A,
     }
   }
 
-  return {LU, pivots, info};
+  return std::tie(LU, pivots, info);
 }
 
 std::tuple<Tensor, Tensor, Tensor> linalg_lu_factor_ex(const Tensor& A, bool pivot, bool check_errors) {
@@ -1639,7 +1639,7 @@ std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, T
     singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
   }
 
-  return {LU, pivots};
+  return std::tie(LU, pivots);
 }
 
 std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 200665127b614..92c8630a22177 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -10,7 +10,8 @@
     "experimental_warning": """This function is "experimental" and it may change in a future PyTorch release.""",
     "sync_note": "When inputs are on a CUDA device, this function synchronizes that device with the CPU.",
     "sync_note_ex": r"When the inputs are on a CUDA device, this function synchronizes only when :attr:`check_errors`\ `= True`.",
-    "sync_note_has_ex": "When inputs are on a CUDA device, this function synchronizes that device with the CPU. For a version of this function that does not synchronize, see :func:`{}`."
+    "sync_note_has_ex": ("When inputs are on a CUDA device, this function synchronizes that device with the CPU. "
" + "For a version of this function that does not synchronize, see :func:`{}`.") } @@ -1900,7 +1901,8 @@ The returned permutation matrix is represented by a 1-indexed vector. `pivots[i] == j` represents that in the `i`-th step of the algorithm, the `i`-th row was permuted with the `j-1`-th row. -On CUDA, one may use :attr:`pivot`\ `= False`. In this case, this function returns the LU decomposition without pivoting if it exists. +On CUDA, one may use :attr:`pivot`\ `= False`. In this case, this function returns the LU +decomposition without pivoting if it exists. Supports inputs of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if the inputs are batches of matrices then @@ -1925,7 +1927,8 @@ A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. Keyword args: - pivot (bool, optional): [Only on CUDA] Whether to compute the LU decomposition with partial pivoting, or the regular LU decomposition. Default: `True`. + pivot (bool, optional): [Only on CUDA] Whether to compute the LU decomposition with partial pivoting, + or the regular LU decomposition. Default: `True`. out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. Returns: @@ -1973,7 +1976,8 @@ A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. Keyword args: - pivot (bool, optional): [Only on CUDA] Whether to compute the LU decomposition with partial pivoting, or the regular LU decomposition. Default: `True`. + pivot (bool, optional): [Only on CUDA] Whether to compute the LU decomposition with partial pivoting, + or the regular LU decomposition. Default: `True`. check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`. out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`. From b53a22f6097ccb13eb07a3511099b129ba613a50 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 20 Oct 2021 18:36:17 +0000 Subject: [PATCH 4/5] Update on "Add linalg.lu_factor" This PR exposes `torch.lu` as `torch.linalg.lu_factor` and `torch.linalg.lu_factor_ex`. This PR also adds support for matrices with zero elements both in the size of the matrix and the batch. Note that this function simply returns empty tensors of the correct size in this case. We add a test and an OpInfo for the new function. This PR also adds documentation for this new function in line of the documentation in the rest of `torch.linalg`. 
cc jianyuh nikitaved pearu mruberry walterddr @IvanYashchuk xwang233 @Lezcano

[ghstack-poisoned]
---
 aten/src/ATen/native/BatchLinearAlgebra.cpp   | 11 ++++----
 tools/autograd/gen_variable_type.py           |  2 +-
 .../_internal/common_methods_invocations.py   | 28 ++++++++++++++++++-
 3 files changed, 33 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 41fcf659ca781..78ca32e8c0e49 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1567,17 +1567,17 @@ std::tuple<Tensor&, Tensor&, Tensor&> linalg_lu_factor_ex_out(const Tensor& A,
   TORCH_CHECK(A.dim() >= 2,
               "expected tensor with 2 or more dimensions, got size: ", A.sizes(), " instead");
   auto req_size = A.sizes().vec();
+  const auto m = req_size.cend()[-2];
+  const auto n = req_size.cend()[-1];
+
   // TODO reimplementation of resize_output with format F-contiguous
   // We should make this a standalone function
   if (resize_output_check(LU, req_size)) {
     // Transpose size
     std::iter_swap(req_size.end() - 1, req_size.end() - 2);
     LU.resize_(req_size, MemoryFormat::Contiguous);
-    std::iter_swap(req_size.end() - 1, req_size.end() - 2);
-    LU.transpose_(-2, -1); // make 'LU' have Fortran contiguous memory layout
+    LU.transpose_(-2, -1); // make 'LU' have Fortran contiguous memory
   }
-  const auto m = req_size.cend()[-2];
-  const auto n = req_size.cend()[-1];
   req_size.pop_back();
   req_size.back() = std::min(m, n);
   at::native::resize_output(pivots, req_size);
@@ -1597,7 +1597,6 @@ std::tuple<Tensor&, Tensor&, Tensor&> linalg_lu_factor_ex_out(const Tensor& A,
   const auto info_contig = info.is_contiguous();
   const auto info_ = borrow_else_clone(info_contig, info, info, /*C-contig*/true);
 
-
   lu_factor_stub(A.device().type(), *LU_, *pivots_, *info_, pivot);
 
   if (!LU_f_contig) {
@@ -1632,7 +1631,7 @@ std::tuple<Tensor, Tensor, Tensor> linalg_lu_factor_ex(const Tensor& A, bool piv
 std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor & LU, Tensor & pivots) {
   auto info = at::empty({0}, A.options().dtype(kInt));
   // We pass check_errors=false so that the error message mentions lu_factor rather than lu_factor_ex
-  at::native::linalg_lu_factor_ex_out(A, pivot, /*check_errors=*/false, LU, pivots, info);
+  at::linalg_lu_factor_ex_out(LU, pivots, info, A, pivot, /*check_errors=*/false);
   if (A.dim() > 2) {
     batchCheckErrors(info, "torch.linalg.lu_factor");
   } else {
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index d00d3ee739f71..3beab52325577 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -105,7 +105,7 @@
     'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
-    'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
+    'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical', 'linalg_lu_factor_ex',
    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
     'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve', '_lu_with_info',
     'linalg_pinv', 'linalg_lstsq',
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index f78d92dd81fd3..28e8ebf6963e8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3427,6 +3427,7 @@ def gen_inputs():
         # Empty cases
         src_sizes = [(0,), (), (1,), (3, 2)]
         src_gen = (make_arg(size) for size in src_sizes)
+
         idx = make_idx((0,), high=1)
         for src in src_gen:
             yield SampleInput(input=src, args=(idx,))
@@ -4197,6 +4198,19 @@ def generate_samples():
 
     return list(generate_samples())
 
+def sample_inputs_linalg_lu_factor(op_info, device, dtype, requires_grad=False, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    # not needed once OpInfo tests support Iterables
+    def generate_samples():
+        batch_shapes = ((), (3,), (3, 3))
+        # pivot=False is only supported on CUDA
+        pivots = (True, False) if torch.device(device).type == "cuda" else (True,)
+        deltas = (-2, -1, 0, +1, +2)
+        for batch_shape, pivot, delta in product(batch_shapes, pivots, deltas):
+            shape = batch_shape + (S + delta, S)
+            yield SampleInput(make_arg(shape), kwargs={"pivot": pivot})
+
+    return list(generate_samples())
 
 def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs):
     from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
@@ -7981,7 +7995,19 @@ def ref_pairwise_distance(input1, input2):
            check_batched_grad=False,
            check_batched_gradgrad=False,
            supports_forward_ad=True,
-           sample_inputs_func=sample_inputs_lu,
+           sample_inputs_func=sample_inputs_linalg_lu_factor,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
+    OpInfo('linalg.lu_factor_ex',
+           aten_name='linalg_lu_factor_ex',
+           op=torch.linalg.lu_factor_ex,
+           dtypes=floating_and_complex_types(),
+           supports_inplace_autograd=False,
+           # we use in-place operations, which cannot be avoided.
+           # This causes vmap failures, hence we skip batched gradient checks
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_linalg_lu_factor,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
     OpInfo('lu',
            op=torch.lu,
From fb16e3cd4f8bf0eb98490e05f6dd2c28b99cafe2 Mon Sep 17 00:00:00 2001
From: lezcano
Date: Wed, 20 Oct 2021 18:53:50 +0000
Subject: [PATCH 5/5] Update on "Add linalg.lu_factor"

This PR exposes `torch.lu` as `torch.linalg.lu_factor` and
`torch.linalg.lu_factor_ex`.

This PR also adds support for matrices with zero elements, both in the
matrix dimensions and in the batch dimensions. Note that in this case the
function simply returns empty tensors of the correct size.

We add a test and an OpInfo for the new function.

This PR also adds documentation for this new function in line with the
documentation in the rest of `torch.linalg`.

cc jianyuh nikitaved pearu mruberry walterddr @IvanYashchuk xwang233 @Lezcano

[ghstack-poisoned]
---
 torch/testing/_internal/common_methods_invocations.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 28e8ebf6963e8..1e46f6ceb47ae 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4200,6 +4200,7 @@
 
 def sample_inputs_linalg_lu_factor(op_info, device, dtype, requires_grad=False, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
     # not needed once OpInfo tests support Iterables
     def generate_samples():
         batch_shapes = ((), (3,), (3, 3))