Add linalg.lu_solve by lezcano · Pull Request #72935 · pytorch/pytorch · GitHub
[go: up one dir, main page]
Skip to content

Add linalg.lu_solve #72935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 33 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
2ec56a5
Add linalg.lu_solve
lezcano Feb 16, 2022
1a494c2
Update on "Add linalg.lu_solve"
lezcano Mar 4, 2022
203e42a
Update on "Add linalg.lu_solve"
lezcano Mar 4, 2022
21ead93
Update on "Add linalg.lu_solve"
lezcano Mar 4, 2022
f1e9965
Update on "Add linalg.lu_solve"
lezcano Mar 4, 2022
be76bfa
Update on "Add linalg.lu_solve"
lezcano Mar 4, 2022
0eade4d
Update on "Add linalg.lu_solve"
lezcano Mar 7, 2022
e4d8b21
Update on "Add linalg.lu_solve"
lezcano Mar 7, 2022
6d321b2
Update on "Add linalg.lu_solve"
lezcano Mar 8, 2022
dfe5ec7
Update on "Add linalg.lu_solve"
lezcano Mar 8, 2022
2b8a9f8
Update on "Add linalg.lu_solve"
lezcano Mar 10, 2022
62eeb99
Update on "Add linalg.lu_solve"
lezcano Mar 10, 2022
7bb093d
Update on "Add linalg.lu_solve"
lezcano Mar 10, 2022
7a221b1
Update on "Add linalg.lu_solve"
lezcano Mar 10, 2022
b7bdf10
Update on "Add linalg.lu_solve"
lezcano Mar 30, 2022
048d7f5
Update on "Add linalg.lu_solve"
lezcano Mar 31, 2022
6724d50
Update on "Add linalg.lu_solve"
lezcano Mar 31, 2022
20b1d04
Update on "Add linalg.lu_solve"
lezcano Mar 31, 2022
b681131
Update on "Add linalg.lu_solve"
lezcano Mar 31, 2022
ebbebec
Update on "Add linalg.lu_solve"
lezcano Apr 1, 2022
454fc7a
Update on "Add linalg.lu_solve"
lezcano Apr 4, 2022
d0b8d66
Update on "Add linalg.lu_solve"
lezcano Apr 5, 2022
44961a4
Update on "Add linalg.lu_solve"
lezcano Apr 5, 2022
b3efb61
Update on "Add linalg.lu_solve"
lezcano Apr 6, 2022
6a7d7f4
Update on "Add linalg.lu_solve"
lezcano Apr 29, 2022
28a0f45
Update on "Add linalg.lu_solve"
lezcano May 2, 2022
665531c
Update on "Add linalg.lu_solve"
lezcano May 2, 2022
6ae765b
Update on "Add linalg.lu_solve"
lezcano May 4, 2022
3baaf14
Update on "Add linalg.lu_solve"
lezcano May 4, 2022
d9f44ed
Update on "Add linalg.lu_solve"
lezcano May 5, 2022
dd3123a
Update on "Add linalg.lu_solve"
lezcano May 5, 2022
3d4e60e
Update on "Add linalg.lu_solve"
lezcano May 5, 2022
f927ccc
Update on "Add linalg.lu_solve"
lezcano May 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 175 additions & 127 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp

Large diffs are not rendered by default.

15 changes: 4 additions & 11 deletions aten/src/ATen/native/BatchLinearAlgebra.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,22 +279,15 @@ DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);

using unpack_pivots_fn = void(*)(
TensorIterator& iter,
const int64_t dim_size
);
const int64_t dim_size);
DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);

using lu_solve_fn = void (*)(
const Tensor& /*b*/,
const Tensor& /*lu*/,
const Tensor& /*pivots*/);
DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);

using lu_solve_trans_fn = void (*)(
const Tensor& /*b*/,
const Tensor& /*lu*/,
const Tensor& /*LU*/,
const Tensor& /*pivots*/,
const Tensor& /*B*/,
TransposeType /*trans*/);
DECLARE_DISPATCH(lu_solve_trans_fn, lu_solve_trans_stub);
DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);

using ldl_factor_fn = void (*)(
const Tensor& /*LD*/,
Expand Down
48 changes: 20 additions & 28 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,34 +1027,34 @@ void lu_factor_kernel(const Tensor& input, const Tensor& pivots, const Tensor& i
For further details, please see the LAPACK documentation for GETRS.
*/
template <typename scalar_t>
void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
void apply_lu_solve(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
#if !AT_BUILD_WITH_LAPACK()
TORCH_CHECK(
false,
"Calling torch.lu_solve on a CPU tensor requires compiling ",
"Calling linalg.lu_solve on a CPU tensor requires compiling ",
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
auto b_data = b.data_ptr<scalar_t>();
auto lu_data = lu.data_ptr<scalar_t>();
auto b_data = B.data_ptr<scalar_t>();
auto lu_data = LU.data_ptr<scalar_t>();
const auto trans = to_blas(transpose);
auto pivots_data = pivots.data_ptr<int>();
auto b_stride = matrixStride(b);
auto lu_stride = lu.dim() > 2 ? lu.stride(-3) : 0;
auto b_stride = matrixStride(B);
auto lu_stride = LU.dim() > 2 ? LU.stride(-3) : 0;
auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
auto batch_size = batchCount(b);
auto batch_size = batchCount(B);

auto n = lu.size(-2);
auto nrhs = b.size(-1);
auto n = LU.size(-2);
auto nrhs = B.size(-1);
auto leading_dimension = std::max<int64_t>(1, n);

int info = 0;

// lu and pivots tensors can be broadcast to b
// here we construct a helper indexing tensor to linearly index into lu and pivots
IntArrayRef lu_batch_shape(lu.sizes().data(), lu.dim() - 2);
IntArrayRef b_batch_shape(b.sizes().data(), b.dim() - 2);
// lu and pivots tensors can be broadcast to B
// here we construct a helper indexing tensor to linearly index into LU and pivots
IntArrayRef lu_batch_shape(LU.sizes().data(), LU.dim() - 2);
IntArrayRef b_batch_shape(B.sizes().data(), B.dim() - 2);
BroadcastLinearIndices lu_index(
batchCount(lu), lu_batch_shape, b_batch_shape);
batchCount(LU), lu_batch_shape, b_batch_shape);

for (const auto i : c10::irange(batch_size)) {
int64_t lu_index_i = lu_index(i);
Expand All @@ -1073,16 +1073,12 @@ void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, Tra
}

// This is a type dispatching helper function for 'apply_lu_solve'
void lu_solve_trans_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_cpu", [&]{
apply_lu_solve<scalar_t>(b, lu, pivots, trans);
void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "linalg.lu_solve_cpu", [&]{
apply_lu_solve<scalar_t>(LU, pivots, B, trans);
});
}

void lu_solve_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
lu_solve_trans_kernel(b, lu, pivots, TransposeType::NoTranspose);
}

template <typename scalar_t>
static void apply_svd(const Tensor& A,
const bool full_matrices,
Expand Down Expand Up @@ -1162,6 +1158,9 @@ void svd_kernel(const Tensor& A,
}

void unpack_pivots_cpu_kernel(TensorIterator& iter, const int64_t dim_size) {
if (iter.numel() == 0) {
return;
}
auto loop = [&](char* const* const data, const int64_t* const strides, const int64_t nelems) {
auto* perm_ptr = data[0];
const auto* pivots_ptr = data[1];
Expand Down Expand Up @@ -1266,13 +1265,6 @@ REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);

REGISTER_ARCH_DISPATCH(lu_solve_trans_stub, DEFAULT, &lu_solve_trans_kernel);
REGISTER_AVX512_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_kernel);
REGISTER_AVX2_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_kernel);
REGISTER_VSX_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_kernel);
REGISTER_ZVECTOR_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_kernel);

REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel);
REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel);
REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel);
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/cuda/LinearAlgebra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ static void _launch_kernel(int total_n_elems, func_t f) {
}

void unpack_pivots_cuda_kernel(TensorIterator& iter, const int64_t dim_size) {
if (iter.numel() == 0) {
return;
}

if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
unpack_pivots_cuda_kernel(sub_iter, dim_size);
Expand Down
10 changes: 2 additions & 8 deletions aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,9 @@ void lazy_svd_kernel(const Tensor& A,
svd_stub(DeviceType::CUDA, A, full_matrices, compute_uv, U, S, Vh, info);
}

void lazy_lu_solve_trans(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
void lazy_lu_solve(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
getTorchLinalgLibrary();
lu_solve_trans_stub(DeviceType::CUDA, b, lu, pivots, trans);
}

void lazy_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
getTorchLinalgLibrary();
lu_solve_stub(DeviceType::CUDA, b, lu, pivots);
lu_solve_stub(DeviceType::CUDA, LU, pivots, B, trans);
}

void lazy_lstsq_kernel(const Tensor& a, Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, std::string driver_name) {
Expand Down Expand Up @@ -164,7 +159,6 @@ REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel);
REGISTER_CUDA_DISPATCH(eig_stub, &lazy_eig_kernel);
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel);
REGISTER_CUDA_DISPATCH(svd_stub, &lazy_svd_kernel)
REGISTER_CUDA_DISPATCH(lu_solve_trans_stub, &lazy_lu_solve_trans);
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve);
REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel);
} // anonymous namespace
Expand Down
Loading
0