From f9bf9726bce5a16aecc9119572f4eba8ecfa8839 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 13 May 2022 07:55:23 +0000 Subject: [PATCH 1/4] [chalf] enable testing for multiple ops --- .../_internal/common_methods_invocations.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 321a25bbdbe1e4..2e993575fcb0d8 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -13929,8 +13929,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): )), OpInfo('permute', ref=np.transpose, - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), supports_out=False, assert_autodiffed=True, autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused @@ -14233,7 +14232,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), )), OpInfo('split', - dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), sample_inputs_func=partial(sample_inputs_split, list_args=False), supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -14252,7 +14251,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): supports_fwgrad_bwgrad=True, supports_out=False), OpInfo('split_with_sizes', - dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), sample_inputs_func=sample_inputs_split_with_sizes, autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused @@ -14423,7 +14422,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): sample_inputs_func=sample_inputs_add_sub), OpInfo('select', aten_backward_name='select_backward', - dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), sample_inputs_func=sample_inputs_select, assert_jit_shape_analysis=True, supports_forward_ad=True, @@ -15156,14 +15155,14 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)), OpInfo('ravel', ref=np.ravel, - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_ravel, ), OpInfo('reshape', - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), sample_inputs_func=sample_inputs_view_reshape, reference_inputs_func=reference_inputs_reshape, supports_out=False, @@ -15172,7 +15171,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): ), OpInfo('reshape_as', op=lambda x, other: x.reshape_as(other), - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), sample_inputs_func=sample_inputs_view_as_reshape_as, supports_out=False, supports_forward_ad=True, @@ -15936,7 +15935,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): ), OpInfo('unfold', op=lambda x, *args: x.unfold(*args), - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -15979,7 +15978,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), )), OpInfo('squeeze', - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), supports_out=False, assert_autodiffed=True, autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused @@ -16081,7 +16080,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): supports_out=False, sample_inputs_func=sample_cumulative_trapezoid,), OpInfo('unsqueeze', - dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -16186,8 +16185,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): ref=_numpy_ref_transpose, aliases=('swapdims', 'swapaxes'), assert_jit_shape_analysis=True, - dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, From 16accbfa4038e053a2c9de1789bb654eff2fad6f Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 13 May 2022 10:03:56 +0000 Subject: [PATCH 2/4] update skips and meta-data --- torch/testing/_internal/common_methods_invocations.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 2e993575fcb0d8..202e5b7dd487de 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15936,6 +15936,7 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs): OpInfo('unfold', op=lambda x, *args: x.unfold(*args), dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -18465,6 +18466,12 @@ def __init__( PythonRefInfo( "_refs.permute", torch_opinfo_name="permute", + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', + 'test_python_reference_consistency', dtypes=(torch.chalf,), device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', + 'test_python_reference_meta_functions', dtypes=(torch.chalf,), device_type='cuda'), + ), ), PythonRefInfo( "_refs.stack", From 2d2f24a48cca14b0a55ce7cbe17ecf0bfaabb7fa Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 13 May 2022 13:48:20 +0000 Subject: [PATCH 3/4] add comment for skip --- torch/testing/_internal/common_methods_invocations.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 202e5b7dd487de..5605dbff0d046c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -18467,8 +18467,10 @@ def __init__( "_refs.permute", torch_opinfo_name="permute", skips=( + # RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf' DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_consistency', dtypes=(torch.chalf,), device_type='cuda'), + # RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf' DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_reference_meta_functions', dtypes=(torch.chalf,), device_type='cuda'), ), From 4dd4919dd1b49207f72f6d073f2ed059481ba7bd Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 13 May 2022 17:29:24 +0000 Subject: [PATCH 4/4] update skip dtype --- torch/testing/_internal/common_methods_invocations.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 2bdae7fa7fc5cd..bf7733efedf9d9 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -18559,10 +18559,12 @@ def __init__( skips=( # RuntimeError: "index_select" not implemented for 'ComplexHalf' # RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf' - DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_reference_consistency"), + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_reference_consistency", + dtypes=(torch.chalf,)), # RuntimeError: "index_select" not implemented for 'ComplexHalf' # RuntimeError: "index_select_cuda" not implemented for 'ComplexHalf' - DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_reference_meta_functions"), + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_reference_meta_functions", + dtypes=(torch.chalf,)), ), ), PythonRefInfo(