8000 Add linalg.lu_factor by lezcano · Pull Request #66933 · pytorch/pytorch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add linalg.lu_factor #66933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 55 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
028a04f
Add linalg.lu_factor
lezcano Oct 20, 2021
5dd416f
Update on "Add linalg.lu_factor"
lezcano Oct 20, 2021
15e815e
Update on "Add linalg.lu_factor"
lezcano Oct 20, 2021
b53a22f
Update on "Add linalg.lu_factor"
lezcano Oct 20, 2021
fb16e3c
Update on "Add linalg.lu_factor"
lezcano Oct 20, 2021
a702fcf
Update on "Add linalg.lu_factor"
lezcano Oct 21, 2021
75f58c3
Update on "Add linalg.lu_factor"
lezcano Oct 21, 2021
57ed735
Update on "Add linalg.lu_factor"
lezcano Oct 21, 2021
beea0a5
Update on "Add linalg.lu_factor"
lezcano Oct 22, 2021
6f3ab47
Update on "Add linalg.lu_factor"
lezcano Oct 22, 2021
b059438
Update on "Add linalg.lu_factor"
lezcano Oct 22, 2021
04e10e0
Update on "Add linalg.lu_factor"
lezcano Oct 22, 2021
5b78366
Update on "Add linalg.lu_factor"
lezcano Oct 29, 2021
7a815e5
Update on "Add linalg.lu_factor"
8000 lezcano Oct 29, 2021
37e1420
Update on "Add linalg.lu_factor"
lezcano Oct 29, 2021
4f0d33b
Update on "Add linalg.lu_factor"
lezcano Oct 31, 2021
c8c133e
Update on "Add linalg.lu_factor"
lezcano Nov 4, 2021
7a98779
Update on "Add linalg.lu_factor"
lezcano Nov 8, 2021
3fd5d73
Update on "Add linalg.lu_factor"
lezcano Nov 11, 2021
fe73990
Update on "Add linalg.lu_factor"
lezcano Nov 12, 2021
67392cc
Update on "Add linalg.lu_factor"
lezcano Nov 15, 2021
38603d8
Update on "Add linalg.lu_factor"
lezcano Nov 17, 2021
81193d4
Update on "Add linalg.lu_factor"
lezcano Nov 18, 2021
dabb4b6
Update on "Add linalg.lu_factor"
lezcano Nov 18, 2021
fdb18d4
Update on "Add linalg.lu_factor"
lezcano Nov 23, 2021
7ab1af3
Update on "Add linalg.lu_factor"
lezcano Nov 23, 2021
662cb75
Update on "Add linalg.lu_factor"
lezcano Nov 24, 2021
96b70f4
Update on "Add linalg.lu_factor"
lezcano Nov 24, 2021
4e605d9
Update on "Add linalg.lu_factor"
lezcano Nov 29, 2021
ee55058
Update on "Add linalg.lu_factor"
lezcano Nov 29, 2021
45c2fab
Update on "Add linalg.lu_factor"
lezcano Nov 29, 2021
4bf1ef6
Update on "Add linalg.lu_factor"
lezcano Nov 29, 2021
bb78a3b
Update on "Add linalg.lu_factor"
lezcano Nov 29, 2021
2ed15de
Update on "Add linalg.lu_factor"
lezcano Nov 30, 2021
3f44be3
Update on "Add linalg.lu_factor"
lezcano Nov 30, 2021
6bb55fc
Update on "Add linalg.lu_factor"
lezcano Nov 30, 2021
d690756
Update on "Add linalg.lu_factor"
lezcano Nov 30, 2021
d85abf3
Update on "Add linalg.lu_factor"
lezcano Dec 1, 2021
e4f5536
Update on "Add linalg.lu_factor"
lezcano Dec 1, 2021
c987c0a
Update on "Add linalg.lu_factor"
lezcano Dec 4, 2021
7168357
Update on "Add linalg.lu_factor"
lezcano Dec 9, 2021
77b3b41
Update on "Add linalg.lu_factor"
lezcano Dec 10, 2021
2b58d06
Update on "Add linalg.lu_factor"
lezcano Dec 14, 2021
38efffe
Update on "Add linalg.lu_factor"
lezcano Dec 14, 2021
6b154ed
Update on "Add linalg.lu_factor"
lezcano Dec 14, 2021
7d0851f
Update on "Add linalg.lu_factor"
lezcano Dec 14, 2021
b9ea49b
Update on "Add linalg.lu_factor"
lezcano Dec 15, 2021
8d76fb8
Update on "Add linalg.lu_factor"
lezcano Dec 19, 2021
77a1d7c
Update on "Add linalg.lu_factor"
lezcano Dec 21, 2021
f1aa99f
Update on "Add linalg.lu_factor"
lezcano Dec 22, 2021
52f4874
Update on "Add linalg.lu_factor"
lezcano Dec 22, 2021
68313e6
Update on "Add linalg.lu_factor"
lezcano Dec 22, 2021
cb5c1b0
Update on "Add linalg.lu_factor"
lezcano Jan 3, 2022
4f7e860
Update on "Add linalg.lu_factor"
lezcano Jan 4, 2022
60e2393
Update on "Add linalg.lu_factor"
lezcano Jan 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 99 additions & 20 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
result = result.unsqueeze_(-1);
}

// lu_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
// lu_factor_stub+lu_solve_stub perform calculations in-place and 'result' must be a copy of 'other_broadcasted'
result.copy_(other_broadcasted);

auto input_working_copy = cloneBatchedColumnMajor(input_broadcasted);
Expand All @@ -945,7 +945,7 @@ static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor
auto pivots_shape = IntArrayRef(input_broadcasted.sizes().data(), input_broadcasted.dim() - 2).vec(); // input_broadcasted.shape[:-2]
pivots_shape.push_back(std::min(input.size(-2), input.size(-1)));
Tensor pivots = at::empty(pivots_shape, input.options().dtype(kInt));
lu_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);
lu_factor_stub(input.device().type(), input_working_copy, pivots, infos, /*compute_pivots=*/true);

// solve the linear system using the LU factorization
lu_solve_stub(input.device().type(), result, input_working_copy, pivots);
Expand Down Expand Up @@ -1571,30 +1571,109 @@ Tensor cholesky_inverse(const Tensor &input, bool upper) {
return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DEFINE_DISPATCH(lu_stub);
DEFINE_DISPATCH(lu_factor_stub);

// TODO: remove check_errors argument
// https://github.com/pytorch/pytorch/issues/64014
std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool check_errors) {
TORCH_CHECK(self.dim() >= 2,
"expected tensor with 2 or more dimensions, got size: ", self.sizes(),
" instead");
auto m = self.size(-2);
auto n = self.size(-1);
auto req_size = self.sizes().vec();
std::tuple<Tensor&, Tensor&, Tensor&> linalg_lu_factor_ex_out(const Tensor& A,
bool pivot,
bool check_errors,
Tensor& LU,
Tensor& pivots,
Tensor& info) {
TORCH_CHECK(A.dim() >= 2,
"expected tensor with 2 or more dimensions, got size: ", A.sizes(), " instead");
auto req_size = A.sizes().vec();
const auto m = req_size.cend()[-2];
const auto n = req_size.cend()[-1];

// TODO reimplementation of resize_output with format F-contiguous
// We should make this a standalone function
if (resize_output_check(LU, req_size)) {
// Transpose size
std::iter_swap(req_size.end() - 1, req_size.end() - 2);
LU.resize_(req_size, MemoryFormat::Contiguous);
LU.transpose_(-2, -1); // make 'LU' have Fortran contiguous memory
}
req_size.pop_back();
req_size.back() = std::min(m, n);
auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt));
at::native::resize_output(pivots, req_size);
req_size.pop_back();
auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
at::native::resize_output(info, req_size);

const auto LU_f_contig = LU.transpose(-2, -1).is_contiguous() ;

if (LU_f_contig && !LU.is_same(A)) {
LU.copy_(A);
}
const auto LU_ = borrow_else_clone(LU_f_contig, LU, A, /*C-contig*/false);

const auto pivots_contig = pivots.is_contiguous();
const auto pivots_ = borrow_else_clone(pivots_contig, pivots, pivots, /*C-contig*/true);

const auto info_contig = info.is_contiguous();
const auto info_ = borrow_else_clone(info_contig, info, info, /*C-contig*/true);

lu_factor_stub(A.device().type(), *LU_, *pivots_, *info_, pivot);

if (!LU_f_contig) {
LU.copy_(*LU_);
}
if (!pivots_contig) {
pivots.copy_(*pivots_);
}
if (!info_contig) {
info.copy_(*info_);
}

if (check_errors) {
if (A.dim() > 2) {
batchCheckErrors(info, "torch.linalg.lu_factor_ex");
} else {
singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor_ex");
}
}

return std::tie(LU, pivots, info);
}

std::tuple<Tensor, Tensor, Tensor> linalg_lu_factor_ex(const Tensor& A, bool pivot, bool check_errors) {
auto LU = at::empty({0}, A.options());
auto pivots = at::empty({0}, A.options().dtype(kInt));
auto info = at::empty({0}, A.options().dtype(kInt));
at::native::linalg_lu_factor_ex_out(A, pivot, check_errors, LU, pivots, info);
return std::make_tuple(std::move(LU), std::move(pivots), std::move(info));
}

std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor & LU, Tensor & pivots) {
auto info = at::empty({0}, A.options().dtype(kInt));
// We pass check_errors as we want to use lu_factor rather than lu_factor_ex in the errors
at::linalg_lu_factor_ex_out(LU, pivots, info, A, pivot, /*chech_errors=*/false);
if (A.dim() > 2) {
batchCheckErrors(info, "torch.linalg.lu_factor");
} else {
singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
}

return std::tie(LU, pivots);
}

std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
Tensor LU, pivots, info;
std::tie(LU, pivots, info) = at::linalg_lu_factor_ex(A, pivot, /*check_errors=*/false);

if (A.dim() > 2) {
batchCheckErrors(info, "torch.linalg.lu_factor");
} else {
singleCheckErrors(info.item<int64_t>(), "torch.linalg.lu_factor");
}

return std::make_tuple(std::move(LU), std::move(pivots));
}

// lu_stub (apply_lu) requires batched column major (Fortran-contiguous) tensors
// 'lu' tensor is modified in-place and must be a copy of 'self'
Tensor lu = cloneBatchedColumnMajor(self);
lu_stub(self.device().type(), lu, pivots_tensor, infos_tensor, compute_pivots);
return std::make_tuple(lu, pivots_tensor, infos_tensor);
// TODO Deprecate this function in favour of linalg_lu_factor_ex
std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
return at::linalg_lu_factor_ex(self, compute_pivots, false);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ using triangular_solve_fn = void (*)(
bool /*unitriangular*/);
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);

using lu_fn = void (*)(
using lu_factor_fn = void (*)(
const Tensor& /*input*/,
const Tensor& /*pivots*/,
const Tensor& /*infos*/,
bool /*compute_pivots*/);
DECLARE_DISPATCH(lu_fn, lu_stub);
DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);

using lu_solve_fn = void (*)(
const Tensor& /*b*/,
Expand Down
22 changes: 11 additions & 11 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -847,14 +847,14 @@ void triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool u
For further details, please see the LAPACK documentation for GETRF.
*/
template <typename scalar_t>
void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#if !AT_BUILD_WITH_LAPACK()
TORCH_CHECK(
false,
"Calling torch.lu on a CPU tensor requires compiling ",
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
TORCH_CHECK(compute_pivots, "lu without pivoting is not implemented on the CPU");
TORCH_CHECK(compute_pivots, "linalg.lu_factor: LU without pivoting is not implemented on the CPU");

auto input_data = input.data_ptr<scalar_t>();
auto pivots_data = pivots.data_ptr<int>();
Expand All @@ -876,9 +876,9 @@ void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bo
}

// This is a type dispatching helper function for 'apply_lu'
void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
void lu_factor_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_cpu", [&]{
apply_lu<scalar_t>(input, pivots, infos, compute_pivots);
apply_lu_factor<scalar_t>(input, pivots, infos, compute_pivots);
});
}

Expand All @@ -890,8 +890,8 @@ void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, b
Args:
* `b` - [in] the right hand side matrix B
[out] the solution matrix X
* `lu` - [in] the LU factorization of matrix A (see at::_lu_with_info)
* `pivots` - [in] the pivot indices (see at::_lu_with_info)
* `lu` - [in] the LU factorization of matrix A (see at::linalg_lu_factor)
* `pivots` - [in] the pivot indices (see at::linalg_lu_factor)

For further details, please see the LAPACK documentation for GETRS.
*/
Expand Down Expand Up @@ -1005,11 +1005,11 @@ REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);

REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel);
REGISTER_AVX512_DISPATCH(lu_stub, &lu_kernel);
REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel);
REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel);
REGISTER_ZVECTOR_DISPATCH(lu_stub, &lu_kernel);
REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel);
REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel);
REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel);
REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel);
REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel);

REGISTER_ARCH_DISPATCH(lu_solve_trans_stub, DEFAULT, &lu_solve_trans_kernel);
REGISTER_AVX512_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_kernel);
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ DEFINE_DISPATCH(linalg_vector_norm_stub);
// where info helps us identify singular matrices.
static inline std::tuple<c10::ExclusivelyOwned<Tensor>, c10::ExclusivelyOwned<Tensor>> _lu_det_P_diag_U(const Tensor& self) {
Tensor pivs, lu, infos;
std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
std::tie(lu, pivs, infos) = at::linalg_lu_factor_ex(self);
TORCH_CHECK(infos.ge(0).all().item<uint8_t>(), "Invalid argument passed to lu");
auto n = self.size(-1);
auto num_exchanges = (at::arange(1, n + 1, pivs.options()) != pivs)
Expand All @@ -135,7 +135,7 @@ static inline std::tuple<c10::ExclusivelyOwned<Tensor>, c10::ExclusivelyOwned<Te
// det(A) = ([is P odd] * -2 + 1) * prod(diag(U))
std::tuple<Tensor, Tensor, Tensor> _det_lu_based_helper(const Tensor& self) {
Tensor lu, pivs, infos;
std::tie(lu, pivs, infos) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors*/false);
std::tie(lu, pivs, infos) = at::linalg_lu_factor_ex(self);
TORCH_CHECK(infos.ge(0).all().item<uint8_t>(), "at::_det_lu_based_helper(): Invalid argument passed to LU");

// find det(P)
Expand Down
14 changes: 14 additions & 0 deletions aten/src/ATen/native/LinearAlgebraUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
return result;
}

/*
* contig chooses between C-contig (true) and F-contig (false)
*/
static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
: c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
: cloneBatchedColumnMajor(clone));
}

/*
* This method is designed to be a faster alternative to
* `cloneBatchedColumnMajor` with some additional features,
Expand Down Expand Up @@ -280,6 +289,11 @@ static inline void singleCheckErrors(int64_t info, const char* name, int64_t bat
} else if (strstr(name, "lstsq")) {
TORCH_CHECK_LINALG(false, name, batch_string,
": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
} else if (strstr(name, "lu_factor")) {
TORCH_CHECK(false, name, batch_string,
": U[", info, ",", info, "] is zero and using it on lu_solve would result in a division by zero. "
"If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or "
"linalg.lu_factor_ex(A, pivot)");
} else {
TORCH_INTERNAL_ASSERT(false, name, ": Unknown error code: ", info, ".");
}
Expand Down
Loading
0