[chalf] enable testing for multiple ops by kshitij12345 · Pull Request #77405 · pytorch/pytorch · GitHub

[chalf] enable testing for multiple ops #77405


Closed
43 changes: 30 additions & 13 deletions torch/testing/_internal/common_methods_invocations.py
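The diff widens the `dtypes` entries of several shape/view `OpInfo`s to include `torch.chalf` (complex32), so the OpInfo-driven test suites now exercise these ops at that dtype. As a rough illustration only (not part of the PR, and assuming a PyTorch build with complex32 enabled), the sketch below shows the kind of calls the expanded dtype lists now cover:

```python
import torch

# torch.chalf tensors are most easily obtained by downcasting a complex64 tensor.
x = torch.randn(2, 3, 4, dtype=torch.cfloat).to(torch.chalf)

y = x.permute(2, 0, 1)            # OpInfo('permute')
z = x.reshape(4, 6)               # OpInfo('reshape')
parts = torch.split(x, 1, dim=0)  # OpInfo('split')
s = x.select(1, 0)                # OpInfo('select')
u = x.unsqueeze(0).squeeze(0)     # OpInfo('unsqueeze') / OpInfo('squeeze')

print(y.dtype)  # torch.complex32
```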
@@ -14161,8 +14161,7 @@ def error_inputs_mean(op_info, device, **kwargs):
)),
OpInfo('permute',
ref=np.transpose,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,
assert_autodiffed=True,
autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
@@ -14479,7 +14478,7 @@ def error_inputs_mean(op_info, device, **kwargs):
dtypes=(torch.chalf,)),
)),
OpInfo('split',
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
sample_inputs_func=partial(sample_inputs_split, list_args=False),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -14498,7 +14497,7 @@ def error_inputs_mean(op_info, device, **kwargs):
supports_fwgrad_bwgrad=True,
supports_out=False),
OpInfo('split_with_sizes',
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
sample_inputs_func=sample_inputs_split_with_sizes,
autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
@@ -14669,7 +14668,7 @@ def error_inputs_mean(op_info, device, **kwargs):
sample_inputs_func=sample_inputs_add_sub),
OpInfo('select',
aten_backward_name='select_backward',
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
sample_inputs_func=sample_inputs_select,
assert_jit_shape_analysis=True,
supports_forward_ad=True,
@@ -15456,14 +15455,14 @@ def error_inputs_mean(op_info, device, **kwargs):
reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
OpInfo('ravel',
ref=np.ravel,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_ravel,
),
OpInfo('reshape',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
sample_inputs_func=partial(sample_inputs_view_reshape, transpose_samples=True),
reference_inputs_func=partial(reference_inputs_view_reshape, transpose_samples=True),
error_inputs_func=error_inputs_reshape,
@@ -15473,7 +15472,7 @@ def error_inputs_mean(op_info, device, **kwargs):
),
OpInfo('reshape_as',
op=lambda x, other: x.reshape_as(other),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
sample_inputs_func=sample_inputs_view_as_reshape_as,
supports_out=False,
supports_forward_ad=True,
@@ -16247,7 +16246,8 @@ def error_inputs_mean(op_info, device, **kwargs):
),
OpInfo('unfold',
op=lambda x, *args: x.unfold(*args),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -16290,7 +16290,7 @@ def error_inputs_mean(op_info, device, **kwargs):
DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
)),
OpInfo('squeeze',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,
assert_autodiffed=True,
autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
@@ -16393,7 +16393,7 @@ def error_inputs_mean(op_info, device, **kwargs):
supports_out=False,
sample_inputs_func=sample_cumulative_trapezoid,),
OpInfo('unsqueeze',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -16492,8 +16492,7 @@ def error_inputs_mean(op_info, device, **kwargs):
ref=_numpy_ref_transpose,
aliases=('swapdims', 'swapaxes'),
assert_jit_shape_analysis=True,
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -18821,10 +18820,28 @@ def __init__(
PythonRefInfo(
"_refs.permute",
torch_opinfo_name="permute",
skips=(
# RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf'
DecorateInfo(unittest.expectedFailure, 'TestCommon',
Collaborator:

What's going on here?

Collaborator Author (kshitij12345):

Oops, forgot to add the error as a comment:

RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf'

'test_python_reference_consistency', dtypes=(torch.chalf,), device_type='cuda'),
# RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf'
DecorateInfo(unittest.expectedFailure, 'TestCommon',
'test_python_reference_meta_functions', dtypes=(torch.chalf,), device_type='cuda'),
),
),
PythonRefInfo(
"_refs.reshape",
torch_opinfo_name="reshape",
skips=(
# RuntimeError: "index_select" not implemented for 'ComplexHalf'
# RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf'
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_reference_consistency",
dtypes=(torch.chalf,)),
# RuntimeError: "index_select" not implemented for 'ComplexHalf'
# RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf'
DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_reference_meta_functions",
dtypes=(torch.chalf,)),
),
),
PythonRefInfo(
"_refs.stack",
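For context on the `expectedFailure` decorations above, which quote `RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf'`: the sketch below is a minimal illustration of that failure mode, assuming a CUDA build from around the time of this PR where `index_select` had no ComplexHalf kernel; on builds that have since gained the kernel, the call simply succeeds.

```python
import torch

if torch.cuda.is_available():
    # Build a small complex32 tensor on CUDA by downcasting complex64.
    x = torch.randn(4, 3, dtype=torch.cfloat, device="cuda").to(torch.chalf)
    idx = torch.tensor([0, 2], device="cuda")
    try:
        torch.index_select(x, 0, idx)
    except RuntimeError as e:
        # On affected builds this prints:
        # "index_select_cuda" not implemented for 'ComplexHalf'
        print(e)
```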