Replace get_all_ type macros with the ATen dispatch macros #71561
Conversation
CI Flow Status: ⚛️ CI Flow Ruleset - Version:
You can add a comment to the PR and tag @pytorchbot with the following commands:

# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and triggering the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.
🔗 Helpful links
💊 CI failures summary and remediations: As of commit 978c55c (more details on the Dr. CI page):
🕵️ 2 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages:
I looked through most of this PR and it seems like an OK "better engineering" codebase-consistency improvement, but we must be extremely careful to map each existing dtype enumeration to the new style. @pmeier, would you also take a look? If you're looking for Python-based improvements like this, @khushi-411, then you might want to look at the following:

Line 1763 in f5a71ec
Line 2552 in f5a71ec
Line 3904 in f5a71ec
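For context on the mapping mentioned above, here is a rough, hedged sketch of how the old `get_all_*` helpers correspond to the newer dtype-collection functions in `torch.testing._internal.common_dtype`. This is illustrative only, not taken from the diff; each call site still has to be checked against the exact `include_*` flags it passes.

```python
# Illustrative correspondence only; the precise dtype sets depend on the
# include_* keyword arguments each call site used (e.g. include_bfloat16=False).
import torch
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    integral_types,
    floating_types_and,
    complex_types,
)

# get_all_dtypes()         -> all dtypes, including bool, half, bfloat16, complex
all_dtypes = all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)

# get_all_int_dtypes()     -> the integral dtypes (uint8 and the signed ints)
int_dtypes = integral_types()

# get_all_fp_dtypes()      -> floating-point dtypes, including half and bfloat16
fp_dtypes = floating_types_and(torch.half, torch.bfloat16)

# get_all_complex_dtypes() -> complex64 and complex128
cplx_dtypes = complex_types()
```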
Thank you so much, @mruberry, for reviewing the PR! I will update the code soon and re-check it carefully :)
Thank you for sharing the issue! :)
I was wondering if this PR should be split so that each PR covers 3-4 files at a time? wdyt @mruberry?
LGTM, thanks @khushi-411 for the patience!
Hi, @mruberry!
@@ -2271,14 +2271,22 @@ def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwarg
sample_inputs = []
for input_args, broadcasts_input in test_cases:
    args = tuple(make_tensor(arg, dtype=dtype, device=device, requires_grad=requires_grad) if isinstance(arg, tuple) else arg
# addcdiv should accept inputs with zero value
# Currently, it throws ZeroDivisionError when the denominator is zero
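On the zero-denominator point in the comment above, here is a minimal, hypothetical sketch (not the PR's actual sample-input code; `_nonzero_like` is a made-up helper) of keeping a denominator tensor away from zero before calling `torch.addcdiv`:

```python
import torch

def _nonzero_like(t, eps=1e-3):
    # Hypothetical helper: nudge near-zero entries away from zero so that addcdiv
    # never divides by zero (integer division by zero raises; float gives inf/nan).
    return torch.where(t.abs() < eps, torch.full_like(t, eps), t)

base = torch.randn(3, 3)
numer = torch.randn(3, 3)
denom = _nonzero_like(torch.randn(3, 3))
out = torch.addcdiv(base, numer, denom, value=0.5)  # base + 0.5 * numer / denom
```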
Nice comment
Thanks!!
Cool! And thanks for reviewing, @pmeier.
Let's make sure the extended tests pass and we'll merge this
@mruberry has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Hi, @pmeier! Hi, @mruberry! Hi, @kshitij12345!
@mruberry has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary: Hi, Team! The PR is motivated by #71153 (comment). It aims to replace `get_all_` type macros with the ATen dispatch macros. (Thanks, Lezcano, for the idea!!) The affected files, each with its full list of `get_all_*` call sites given in a collapsed details block, are:

test/test_autograd.py
test/test_binary_ufuncs.py
test/test_complex.py
test/test_foreach.py
test/test_linalg.py
test/test_nn.py
test/test_numpy_interop.py
test/test_ops.py
test/test_reductions.py
test/test_serialization.py
test/test_shape_ops.py
test/test_sort_and_select.py
test/test_sparse.py
test/test_sparse_csr.py
test/test_tensor_creation_ops.py
test/test_testing.py
test/test_torch.py
test/test_type_promotion.py
test/test_unary_ufuncs.py
test/test_view_ops.py

I'm looking forward to your viewpoints. Thanks :)

cc: mruberry kshitij12345 anjali411

Pull Request resolved: #71561
Reviewed By: samdow
Differential Revision: D34856571
Pulled By: mruberry
fbshipit-source-id: 0dca038bcad5cf69906245c496d2e61ac3876335
Hey @khushi-411.
This pull request has been reverted by ef066f0. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert, and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).
Reverted the PR as it broke a bunch of tests on trunk, see https://hud.pytorch.org/pytorch/pytorch/commit/3ded7b1da349170e1df3e694bbcaabb8639f5fb8
Looks like a logical merge conflict with a test that was changed recently; should be a straightforward fix.
Hi!
@khushi-411 The procedure to merge a reverted PR is to open a new PR with the required changes and tag the same reviewers (as mentioned in #71561 (comment)). Thanks for your patience!
Got that. I'll open a new PR. Thanks, @kshitij12345! :)
Created a new PR: #74289. Please take a look :)
Hi, this PR is the follow-up of #71561. (The previous PR had a couple of merge conflicts and was reverted; this PR resolves that.) Please take a look. Thanks! cc: @pmeier @mruberry @kshitij12345
Pull Request resolved: #74289
Approved by: https://github.com/pmeier, https://github.com/mruberry
Hi, Team!
The PR is motivated by #71153 (comment). It aims to replace `get_all_` type macros with the ATen dispatch macros. The files it iterates over are listed below. (Thanks, @lezcano, for the idea!!)
test/test_autograd.py
test/test_binary_ufuncs.py
test/test_complex.py
test/test_foreach.py
test/test_linalg.py
test/test_nn.py
test/test_numpy_interop.py
test/test_ops.py
test/test_reductions.py
test/test_serialization.py
test/test_shape_ops.py
test/test_sort_and_select.py
test/test_sparse.py
test/test_sparse_csr.py
test/test_tensor_creation_ops.py
test/test_testing.py
test/test_torch.py
test/test_type_promotion.py
test/test_unary_ufuncs.py
test/test_view_ops.py
I'm looking forward to your viewpoints. Thanks :)
cc: @mruberry @kshitij12345 @anjali411
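As a final illustration (a hedged sketch, not a line from the actual diff; `TestExample` and `test_example` are made-up names), a typical call-site change in these test files swaps the old helper inside the `@dtypes` decorator for an equivalent new-style collection:

```python
import torch
from torch.testing._internal.common_device_type import dtypes
from torch.testing._internal.common_dtype import all_types_and_complex_and

# Before (old style):
#     @dtypes(*get_all_dtypes(include_bfloat16=False))
#     def test_example(self, device, dtype): ...

# After (new style, illustrative only) -- the argument list must reproduce
# exactly the dtypes the old call returned (here: everything except bfloat16):
class TestExample:
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test_example(self, device, dtype):
        pass
```

The key review concern is that the new argument list reproduces exactly the dtype set the old call returned; any mismatch silently changes test coverage.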