Added {logical_not, trace} refs, moved logical ops to use method overloads by Chillee · Pull Request #79000 · pytorch/pytorch · GitHub

Added {logical_not, trace} refs, moved logical ops to use method overloads #79000


Closed

Chillee wants to merge 16 commits

Changes from all commits

Commits (16)
5ef8502
Added {logical_not, trace} refs, moved logical ops to use method over…
Chillee Jun 7, 2022
286a2df
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 7, 2022
ee82cdd
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 7, 2022
2958830
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 7, 2022
3943353
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 7, 2022
768fdb3
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 7, 2022
c32ffb2
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 8, 2022
fb3d5e3
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 8, 2022
0866438
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 8, 2022
e690812
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 8, 2022
a656b75
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 8, 2022
5ec1f52
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 9, 2022
a5b79c1
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 9, 2022
fd7c60b
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 9, 2022
d692cb6
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 9, 2022
003d457
Update on "Added {logical_not, trace} refs, moved logical ops to use …
Chillee Jun 9, 2022
4 changes: 0 additions & 4 deletions test/test_meta.py
@@ -422,7 +422,6 @@ def run_meta_crossref(
torch.mvlgamma: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::_local_scalar_dense, aten::mvlgamma.out
torch.nanmean: {bf16, f16, f32, f64},
torch.nanmedian: {bf16, f32, f64, i16, i32, i64, i8, u8}, # aten::nanmedian, aten::nanmedian.dim_values
-torch.nanquantile: {f32, f64},
torch.nansum: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8}, # aten::nansum, aten::nansum.out
torch.nn.functional.conv1d: {bf16, f32, f64, i64},
torch.nn.functional.conv2d: {bf16, f32, f64, i64},
@@ -493,7 +492,6 @@ def run_meta_crossref(
torch.functional.cdist: {f32, f64},
torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8},
torch.inner: {bf16, f32, f64, i16, i32, i64, i8, u8},
-torch.logical_not: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},
torch.nn.functional.cross_entropy: {bf16, f32, f64},
torch.nn.functional.interpolate: {bf16, f32, f64, u8},
torch.nn.functional.nll_loss: {bf16, f32, f64}, # TODO
@@ -666,8 +664,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.log_sigmoid_forward.output: {bf16, f64, f32},
aten.logcumsumexp.default: {bf16, f64, f32},
aten.logcumsumexp.out: {bf16, f64, f32},
-aten.logical_not.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
-aten.logical_not_.default: {bf16, f16, f64, f32},
aten.masked_select.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.masked_select.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
aten.max_pool3d_with_indices.default: {f64, f32},
1 change: 0 additions & 1 deletion test/test_ops.py
@@ -372,7 +372,6 @@ def _to_tensormeta(x):
return x

# TODO: iterate over requires_grad true/false
-    inps = tuple(op.reference_inputs(device, dtype, requires_grad=False))
    for sample in op.reference_inputs(device, dtype, requires_grad=False):
        result = op(sample.input, *sample.args, **sample.kwargs)

10 changes: 0 additions & 10 deletions torch/_decomp/decompositions.py
@@ -1063,11 +1063,6 @@ def _fused_dropout_decomposition(input, p, generator=None):
return (res, mask)


-@register_decomposition(aten.logical_not)
-def logical_not(self: Tensor) -> Tensor:
-    return ~self.to(dtype=torch.bool)


@register_decomposition(aten.xlogy.Tensor)
@pw_cast_for_int_to_real
def xlogy(self: Tensor, other: Tensor) -> Tensor:
@@ -1224,11 +1219,6 @@ def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
return result.log().add(maxes_squeezed)


-@register_decomposition(aten.trace.default)
-def trace(self: Tensor) -> Tensor:
-    return torch.sum(torch.diag(self))


# nb: Should use acc_t, not op_math
@register_decomposition(aten.log_sigmoid_forward)
@out_wrapper_multi('output', 'buffer')
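Both removed decompositions reappear below as Python refs: logical_not in torch/_refs/__init__.py, and trace registered directly against the aten op. For the record, the removed logical_not formula agrees with the eager op; a quick self-contained check (a sketch, public APIs only):

import torch

# The deleted decomposition computed logical_not as ~a.to(torch.bool).
# That agrees with eager torch.logical_not for bool and non-bool inputs.
for x in (torch.tensor([0.0, 1.5, 0.0]), torch.tensor([True, False])):
    assert torch.equal(~x.to(dtype=torch.bool), torch.logical_not(x))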
7 changes: 6 additions & 1 deletion torch/_prims/context.py
@@ -27,7 +27,12 @@ def torch_to_refs_map():
(torch.nn.functional, torch._refs.nn.functional),
(torch.special, torch._refs.special),
]
-    r = {}
+    r: Dict[Any, Any] = {
+        torch.Tensor.__invert__: torch._refs.bitwise_not,
+        torch.Tensor.__xor__: torch._refs.bitwise_xor,
+        torch.Tensor.__and__: torch._refs.bitwise_and,
+        torch.Tensor.__or__: torch._refs.bitwise_or,
+    }
for mod_torch, mod_refs in modules:
for s in mod_refs.__all__: # type: ignore[attr-defined]
r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)
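These dunder entries are what make the method-overload style in the refs below traceable: writing `a & b` dispatches through torch.Tensor.__and__, and the map translates that bound method to its _refs counterpart under the prims context. A small check of the mapping (a sketch, assuming torch_to_refs_map stays importable from torch._prims.context):

import torch
import torch._refs
from torch._prims.context import torch_to_refs_map

# Tensor operator overloads should now route to the refs.
mapping = torch_to_refs_map()
assert mapping[torch.Tensor.__and__] is torch._refs.bitwise_and
assert mapping[torch.Tensor.__invert__] is torch._refs.bitwise_not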
42 changes: 32 additions & 10 deletions torch/_refs/__init__.py
@@ -88,6 +88,7 @@
"square",
"tan",
"tanh",
"trace",
#
# Elementwise Binary References
#
@@ -119,6 +120,7 @@
# 'ldexp',
"le",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"lt",
@@ -952,10 +954,10 @@ def isclose(

def _logical_and(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
-        a = ne(a, 0)
+        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
-        b = ne(b, 0)
-    return bitwise_and(a, b)
+        b = b != 0
+    return a & b


logical_and = _make_elementwise_binary_reference(
@@ -965,12 +967,25 @@ def _logical_and(a: TensorLikeType, b: TensorLikeType):
)


+def _logical_not(a: TensorLikeType):
+    if not utils.is_boolean_dtype(a.dtype):
+        return a == 0
+    return ~a
+
+
+logical_not = _make_elementwise_unary_reference(
+    _logical_not,
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
+    aten_op=torch.ops.aten.logical_not,
+)
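The new ref's semantics in one check (a sketch against eager mode): non-boolean inputs are compared against zero, boolean inputs are inverted.

import torch

x = torch.tensor([0, 3, 0, -1])
m = torch.tensor([True, False])
assert torch.equal(torch.logical_not(x), x == 0)  # non-bool: a == 0
assert torch.equal(torch.logical_not(m), ~m)      # bool: ~a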


def _logical_or(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
-        a = ne(a, 0)
+        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
-        b = ne(b, 0)
-    return bitwise_or(a, b)
+        b = b != 0
+    return a | b


logical_or = _make_elementwise_binary_reference(
@@ -982,10 +997,10 @@

def _logical_xor(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
-        a = ne(a, 0)
+        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
-        b = ne(b, 0)
-    return bitwise_xor(a, b)
+        b = b != 0
+    return a ^ b
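Since all three rewrites above lean on operator overloads, a quick sanity check (a sketch using only public eager APIs) that the overloaded forms agree with the logical ops:

import torch

a = torch.tensor([0, 1, 2, 0])
b = torch.tensor([0, 0, 3, 4])
assert torch.equal((a != 0) & (b != 0), torch.logical_and(a, b))
assert torch.equal((a != 0) | (b != 0), torch.logical_or(a, b))
assert torch.equal((a != 0) ^ (b != 0), torch.logical_xor(a, b))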


# TODO: skip unnecessary conversion of long to float
@@ -2422,6 +2437,13 @@ def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
return item(all(eq(a, b))) # type: ignore[return-value]


+# populate the decomp table
+@register_decomposition(torch.ops.aten.trace)
+def trace(self: TensorLikeType) -> TensorLikeType:
+    utils.check(
+        self.ndim == 2, lambda: f"expected a matrix, but got tensor with dim {self.ndim}"
+    )
+    return torch.sum(torch.diag(self, 0))
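Registering the ref as the aten decomposition means one definition serves both roles. Its formula in miniature (a sketch):

import torch

m = torch.arange(9.0).reshape(3, 3)
# trace is the sum of the main diagonal: 0 + 4 + 8 = 12
assert torch.trace(m) == torch.sum(torch.diag(m, 0))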


import torch._refs.nn.functional
import torch._refs.special
20 changes: 19 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -3632,6 +3632,10 @@ def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs):
requires_grad=requires_grad))),)


+def error_inputs_trace(op, device):
+    yield ErrorInput(
+        SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)),
+        error_regex="expected a matrix",
+    )
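Spelled out, the ErrorInput above asserts that trace rejects non-2-D inputs with a matching message; roughly (a sketch, assuming utils.check raises RuntimeError):

import torch
import torch._refs

try:
    torch._refs.trace(torch.ones(3, 4, 5))
    raise AssertionError("expected trace to reject a 3-D input")
except RuntimeError as e:
    assert "expected a matrix" in str(e)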


def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
cases = (((S, S, S), (2, 1, 0.5)),
@@ -4300,7 +4304,6 @@ def error_inputs_embedding(op_info, device, **kwargs):
def error_inputs_t(op_info, device, **kwargs):
    yield ErrorInput(
        SampleInput(torch.randn(2, 3, 4, 5, device=device)),
-        error_type=RuntimeError,
        error_regex="expects a tensor with <= 2",
    )

@@ -17573,6 +17576,7 @@ def error_inputs_mean(op_info, device, **kwargs):
OpInfo('trace',
dtypes=all_types_and_complex(),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+error_inputs_func=error_inputs_trace,
supports_inplace_autograd=False,
supports_out=False,
supports_forward_ad=True,
@@ -20410,6 +20414,10 @@ def __init__(
),
)
),
+ElementwiseUnaryPythonRefInfo(
+    "_refs.logical_not",
+    torch_opinfo_name="logical_not",
+),
ElementwiseBinaryPythonRefInfo(
"_refs.logical_or",
torch_opinfo_name="logical_or",
@@ -20914,6 +20922,16 @@ def __init__(
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
),
),
+PythonRefInfo(
+    "_refs.trace",
+    torch_opinfo_name="trace",
+    decorators=(
+        # TODO: torch.diag is currently not supported by either refs, meta funcs, or NVFuser
+        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
+        DecorateInfo(unittest.skip("diag is not supported by meta"), 'TestCommon', 'test_python_ref_meta'),
+        DecorateInfo(unittest.skip("diag is not supported by nvfuser"), 'TestCommon', 'test_python_ref_executor'),
+    ),
+),
#
# Tensor Creation Reference OpInfos
#