Description
🐛 Describe the bug
Sometimes we want dtype-specific typing for `Tensor`, e.g. `LongTensor`, but the conversion `Tensor.long()` (and similar methods) still returns `Tensor`, so the following code fails mypy:
```python
from torch import LongTensor, Tensor

def foo(x: Tensor) -> LongTensor:
    return x.long()
```

```
a.py:5: error: Incompatible return value type (got "Tensor", expected "LongTensor")  [return-value]
```
In fact, it seems no function or method is annotated with `LongTensor`, and the only way to obtain a `LongTensor` for typing purposes is to call the constructor `LongTensor()` (or, of course, to use `typing.cast`).
A simple solution is to annotate the dtype conversion methods (`Tensor.long()` and similar) so that they return the corresponding dtype-specific class, e.g. `LongTensor`.
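A minimal sketch of what the simple fix would look like, using hypothetical stand-in classes rather than the real `torch.Tensor` or its stubs: each conversion method is annotated with its dtype-specific subclass, so a `foo` like the one above type-checks without `typing.cast`.

```python
from __future__ import annotations


class Tensor:
    """Hypothetical stand-in for torch.Tensor; not the real class."""

    def __init__(self, data: list[float]) -> None:
        self.data = data

    def long(self) -> LongTensor:
        # Proposed change: annotate with the dtype-specific class instead of
        # plain Tensor, so type checkers know the result is a LongTensor.
        return LongTensor([int(v) for v in self.data])

    def double(self) -> DoubleTensor:
        # Same idea for another conversion method.
        return DoubleTensor([float(v) for v in self.data])


class LongTensor(Tensor):
    """Hypothetical stand-in for torch.LongTensor."""


class DoubleTensor(Tensor):
    """Hypothetical stand-in for torch.DoubleTensor."""


def foo(x: Tensor) -> LongTensor:
    return x.long()  # accepted by mypy under these annotations
```

One caveat I'm unsure about: here `LongTensor` is an ordinary subclass of `Tensor`, which may not match how `torch.LongTensor` actually relates to `Tensor` at runtime, so in practice the change might have to live in the type annotations/stubs only.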
A more comprehensive solution could be to make `Tensor` generic over its dtype (similar to NumPy), so that methods taking a `dtype=` keyword argument can also be typed precisely. However, this might involve too much work.
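A rough sketch of the generic alternative, loosely modeled on how `numpy.typing.NDArray` is parameterized by a scalar type. Everything here is hypothetical: `int64` and `float32` are marker classes standing in for `torch.int64`/`torch.float32`, and `zeros` is a toy factory, not `torch.zeros`. `typing.overload` lets the `dtype=` argument select the static return type. Since real dtypes such as `torch.int64` are objects rather than classes, a real implementation would need a different encoding; this only illustrates the shape of the idea.

```python
from __future__ import annotations

from typing import Generic, TypeVar, overload


class int64:
    """Hypothetical dtype marker standing in for torch.int64."""


class float32:
    """Hypothetical dtype marker standing in for torch.float32."""


DT = TypeVar("DT")


class Tensor(Generic[DT]):
    """Hypothetical tensor whose dtype is tracked as a type parameter."""

    def __init__(self, data: list[float], dtype: type[DT]) -> None:
        self.data = data
        self.dtype = dtype

    def long(self) -> Tensor[int64]:
        # Conversion methods pin the type parameter; no LongTensor subclass needed.
        return Tensor([int(v) for v in self.data], int64)


@overload
def zeros(n: int, dtype: type[int64]) -> Tensor[int64]: ...
@overload
def zeros(n: int, dtype: type[float32] = ...) -> Tensor[float32]: ...
def zeros(
    n: int, dtype: type[int64] | type[float32] = float32
) -> Tensor[int64] | Tensor[float32]:
    # The dtype= argument picks an overload, so the static result type follows it.
    if dtype is int64:
        return Tensor([0 for _ in range(n)], int64)
    return Tensor([0.0] * n, float32)


def foo(x: Tensor[float32]) -> Tensor[int64]:
    return x.long()  # type-checks: long() returns Tensor[int64]
```

With this design, `zeros(3, dtype=int64)` is inferred as `Tensor[int64]` and `zeros(3)` as `Tensor[float32]`, which is the kind of `dtype=`-aware typing the per-class approach cannot express.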
That said, I would also like to confirm whether there is a specific reason not to fix this either way. Taking `LongTensor` as an example again, it appears many times in the docstrings (e.g. `argmax` should be another way to create a `LongTensor`), but it is never actually used in type annotations, so I wonder whether this is intentional.
Versions
nightly torch-2.7.0.dev20250305