Deprecate torch.lu by lezcano · Pull Request #77636 · pytorch/pytorch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Deprecate torch.lu #77636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion aten/src/ATen/native/BatchLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1343,7 +1343,7 @@ bool _requires_fw_or_bw_grad(const Tensor& input) {
// Below of the definitions of the functions operating on a batch that are going to be dispatched
// in the main helper functions for the linear algebra operations

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Solves a system of linear equations matmul(input, x) = other in-place
// LAPACK/MAGMA error codes are saved in 'infos' tensor, they are not checked here
Expand Down Expand Up @@ -2073,6 +2073,17 @@ std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {

// Deprecated backend of torch.lu.
//
// Emits a one-time deprecation warning that points users to
// torch.linalg.lu_factor / torch.linalg.lu_factor_ex, then forwards to
// at::linalg_lu_factor_ex.
//
// Parameters:
//   self           - the input matrix (or batch of matrices) to factorize.
//   compute_pivots - whether partial pivoting is performed; forwarded as the
//                    second argument of at::linalg_lu_factor_ex.
//   (unnamed bool) - accepted for signature compatibility but ignored.
//                    The literal `false` is always forwarded as the third
//                    argument of at::linalg_lu_factor_ex (presumably its
//                    check_errors flag -- NOTE(review): confirm against
//                    native_functions.yaml).
// Returns the (LU, pivots, info) tuple produced by at::linalg_lu_factor_ex.
std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
  // `pivot` is keyword-only in torch.linalg.lu_factor{,_ex}, so the suggested
  // replacement code must spell it as `pivot=...`. A positional call would
  // raise a TypeError when users copy it verbatim.
  TORCH_WARN_ONCE(
    "torch.lu is deprecated in favor of torch.linalg.lu_factor / torch.linalg.lu_factor_ex and will be ",
    "removed in a future PyTorch release.\n",
    "LU, pivots = torch.lu(A, compute_pivots)\n",
    "should be replaced with\n",
    "LU, pivots = torch.linalg.lu_factor(A, pivot=compute_pivots)\n",
    "and\n",
    "LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)\n",
    "should be replaced with\n",
    "LU, pivots, info = torch.linalg.lu_factor_ex(A, pivot=compute_pivots)"
  );
  return at::linalg_lu_factor_ex(self, compute_pivots, false);
}

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1920,7 +1920,7 @@ static void apply_lu_factor_batched_magma(const Tensor& input, const Tensor& piv
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(
false,
"Calling torch.lu on a CUDA tensor requires compiling ",
"Calling linalg.lu_factor on a CUDA tensor requires compiling ",
"PyTorch with MAGMA. Please rebuild with MAGMA.");
#else
auto input_data = input.data_ptr<scalar_t>();
Expand Down
6 changes: 3 additions & 3 deletions test/mobile/model_test/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,9 @@ def blas_lapack_ops(self):
# torch.logdet(m),
# torch.slogdet(m),
# torch.lstsq(m, m),
# torch.lu(m),
# torch.lu_solve(m, *torch.lu(m)),
# torch.lu_unpack(*torch.lu(m)),
# torch.linalg.lu_factor(m),
# torch.lu_solve(m, *torch.linalg.lu_factor(m)),
# torch.lu_unpack(*torch.linalg.lu_factor(m)),
torch.matmul(m, m),
torch.matrix_power(m, 2),
# torch.matrix_rank(m),
Expand Down
14 changes: 1 addition & 13 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9149,20 +9149,8 @@ def istft(input, n_fft):
inps2 = (stft(*inps), inps[1])
self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))

def lu(x):
# type: (Tensor) -> Tuple[Tensor, Tensor]
return torch.lu(x)

self.checkScript(lu, (torch.randn(2, 3, 3),))

def lu_infos(x):
# type: (Tensor) -> Tuple[Tensor, Tensor, Tensor]
return torch.lu(x, get_infos=True)

self.checkScript(lu_infos, (torch.randn(2, 3, 3),))

def lu_unpack(x):
A_LU, pivots = torch.lu(x)
A_LU, pivots = torch.linalg.lu_factor(x)
return torch.lu_unpack(A_LU, pivots)

for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
Expand Down
19 changes: 5 additions & 14 deletions test/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4213,7 +4213,7 @@ def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_c
size_b = size_b[1:]

if well_conditioned:
PLU = torch.lu_unpack(*torch.lu(make_randn(*size_a)))
PLU = torch.linalg.lu(make_randn(*size_a))
if uni:
# A = L from PLU
A = PLU[1].transpose(-2, -1).contiguous()
Expand Down Expand Up @@ -4900,15 +4900,6 @@ def call_torch_fn(*args, **kwargs):
self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True))

if torch._C.has_lapack:
# lu
A_LU, pivots = fn(torch.lu, (0, 5, 5))
self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape])
A_LU, pivots = fn(torch.lu, (0, 0, 0))
self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape])
A_LU, pivots = fn(torch.lu, (2, 0, 0))
self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])

@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_and_complex_types_and(
Expand Down Expand Up @@ -5276,7 +5267,7 @@ def gen_matrices():
@dtypes(torch.double)
def test_lu_unpack_check_input(self, device, dtype):
x = torch.rand(5, 5, 5, device=device, dtype=dtype)
lu_data, lu_pivots = torch.lu(x, pivot=True)
lu_data, lu_pivots = torch.linalg.lu_factor(x)

with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
torch.lu_unpack(lu_data, lu_pivots.long())
Expand Down Expand Up @@ -7163,7 +7154,7 @@ def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):

b = torch.randn(*b_dims, dtype=dtype, device=device)
A = make_A(*A_dims)
LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot)
LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A)
self.assertEqual(info, torch.zeros_like(info))
return b, A, LU_data, LU_pivots

Expand Down Expand Up @@ -7207,7 +7198,7 @@ def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
# Tests tensors with 0 elements
b = torch.randn(3, 0, 3, dtype=dtype, device=device)
A = torch.randn(3, 0, 0, dtype=dtype, device=device)
LU_data, LU_pivots = torch.lu(A)
LU_data, LU_pivots = torch.linalg.lu_factor(A)
self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))

sub_test(True)
Expand Down Expand Up @@ -7242,7 +7233,7 @@ def run_test(A_dims, b_dims, pivot=True):
A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
b = make_tensor(b_dims, dtype=dtype, device=device)
x_exp = np.linalg.solve(A.cpu(), b.cpu())
LU_data, LU_pivots = torch.lu(A, pivot=pivot)
LU_data, LU_pivots = torch.linalg.lu_factor(A)
x = torch.lu_solve(b, LU_data, LU_pivots)
self.assertEqual(x, x_exp)

Expand Down
12 changes: 6 additions & 6 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5803,16 +5803,16 @@ def merge_dicts(*dicts):
lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor

Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
LU factorization of A from :meth:`torch.lu`.
LU factorization of A from :func:`~linalg.lu_factor`.

This function supports ``float``, ``double``, ``cfloat`` and ``cdouble`` dtypes for :attr:`input`.

Arguments:
b (Tensor): the RHS tensor of size :math:`(*, m, k)`, where :math:`*`
is zero or more batch dimensions.
LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu` of size :math:`(*, m, m)`,
LU_data (Tensor): the pivoted LU factorization of A from :meth:`~linalg.lu_factor` of size :math:`(*, m, m)`,
where :math:`*` is zero or more batch dimensions.
LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`torch.lu` of size :math:`(*, m)`,
LU_pivots (IntTensor): the pivots of the LU factorization from :meth:`~linalg.lu_factor` of size :math:`(*, m)`,
where :math:`*` is zero or more batch dimensions.
The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of
:attr:`LU_data`.
Expand All @@ -5824,9 +5824,9 @@ def merge_dicts(*dicts):

>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3, 1)
>>> A_LU = torch.lu(A)
>>> x = torch.lu_solve(b, *A_LU)
>>> torch.norm(torch.bmm(A, x) - b)
>>> LU, pivots = torch.linalg.lu_factor(A)
>>> x = torch.lu_solve(b, LU, pivots)
>>> torch.dist(A @ x, b)
tensor(1.00000e-07 *
2.8312)
""".format(**common_args))
Expand Down
6 changes: 5 additions & 1 deletion torch/backends/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,14 @@ def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend]
* :func:`torch.linalg.cholesky_ex`
* :func:`torch.cholesky_solve`
* :func:`torch.cholesky_inverse`
* :func:`torch.lu`
* :func:`torch.linalg.lu_factor`
* :func:`torch.linalg.lu`
* :func:`torch.linalg.lu_solve`
* :func:`torch.linalg.qr`
* :func:`torch.linalg.eigh`
* :func:`torch.linalg.eigvalsh`
* :func:`torch.linalg.svd`
* :func:`torch.linalg.svdvals`
'''

if backend is None:
Expand Down
17 changes: 17 additions & 0 deletions torch/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,23 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
``True``.

.. warning::

:func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
future PyTorch release.
``LU, pivots = torch.lu(A, compute_pivots)`` should be replaced with

.. code:: python

LU, pivots = torch.linalg.lu_factor(A, pivot=compute_pivots)

``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with

.. code:: python

LU, pivots, info = torch.linalg.lu_factor_ex(A, pivot=compute_pivots)

.. note::
* The returned permutation matrix for every matrix in the batch is
represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
Expand Down
7 changes: 6 additions & 1 deletion torch/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,11 @@
Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
the output has the same batch dimensions.

""" + fr"""
.. note:: This function is computed using :func:`torch.linalg.lu_factor`.
{common_notes["sync_note"]}
""" + r"""

.. seealso::

:func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the
Expand Down Expand Up @@ -372,7 +377,7 @@
the output has the same batch dimensions.

""" + fr"""
.. note:: This function is computed using :func:`torch.lu`.
.. note:: This function is computed using :func:`torch.linalg.lu_factor`.
{common_notes["sync_note"]}
""" + r"""

Expand Down
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6904,7 +6904,7 @@ def sample_inputs_linalg_lu(op_info, device, dtype, requires_grad=False, **kwarg
make_arg = partial(make_fn, dtype=dtype, device=device, requires_grad=requires_grad)

def out_fn(output):
if op_info.name in ("linalg.lu"):
if op_info.name == "linalg.lu":
return output[1], output[2]
else:
return output
Expand Down
0