[RFC] Generalized Per-Operator Device Capability Registry for PyTorch Operator Testing #154017

ankurneog opened this issue May 21, 2025 · 0 comments
ankurneog commented May 21, 2025

🚀 The feature, motivation and pitch

RFC: Generalized Per-Operator Device Capability Registry for PyTorch Operator Testing

Summary

This RFC proposes the introduction of a Generalized Device Capability Registry to centrally declare and manage per-operator, per-device support information in PyTorch’s testing framework. It aims to replace manual, scattered constructs like dtypesIfCUDA, skips, and xfail in OpInfo declarations with a structured, extensible, and declarative registry.


Motivation

PyTorch's operator testing suite (common_methods_invocations.py) uses constructs like dtypesIfCUDA, skips, and backend-specific decorators directly inside each OpInfo. This approach leads to:

  • Redundant and error-prone boilerplate
  • Difficulty in onboarding new device types or modifying support
  • Scattered device-specific logic across hundreds of operators
  • Limited extensibility for non-dtype constraints (e.g., 0-dim support, layout restrictions)

This RFC proposes a scalable, maintainable abstraction for managing such variations across device backends.


Explanation

Problem

Currently, every operator test in PyTorch manually encodes its device-specific capabilities:

https://github.com/pytorch/pytorch/blob/a636a92ee9f9d31c1ee34416afabdc70da83f75c/torch/testing/_internal/common_methods_invocations.py#L12154

OpInfo('addmm',
       dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
       dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
       dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
       dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
       assert_autodiffed=True,
       supports_forward_ad=True,
       supports_fwgrad_bwgrad=True,
       gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
       sample_inputs_func=sample_inputs_addmm,
       skips=(
           # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
           DecorateInfo(
               unittest.skip("Skipped!"),
               'TestSchemaCheckModeOpInfo',
               'test_schema_correctness',
               dtypes=(torch.complex64, torch.complex128)),
           DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
                        "TestConsistency", "test_output_grad_match", device_type="mps"),
       ))

This approach does not scale as we introduce more ops, devices, dtypes, and constraints.


Proposed Solution

Introduce a device capability registry that allows per-operator support declarations in one centralized location. For example:

cuda_cap = register_device("cuda")
cuda_cap.register_op_dtypes("add", torch.float32, torch.float16, torch.bfloat16)
cuda_cap.add_op_constraint("add", "supports_zero_dim", True)

Operators (via OpInfo) can query this at runtime:

opinfo.get_supported_dtypes(device)  # device: "cuda", "hpu", "cpu"
opinfo.get_constraints(device)

This separates core operator logic from backend variability and supports dynamic test generation, better maintainability, and easier backend integration.
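For illustration, a sketch of how a test could consume a non-dtype constraint at runtime, here the "supports_zero_dim" constraint registered above. The filtering helper and the assumption that each sample's primary input is a tensor are hypothetical, not part of this proposal's API:

def filter_samples_for_device(opinfo, device_type, samples):
    # Hypothetical helper: drop sample inputs a backend declares it cannot handle,
    # instead of hard-coding a skip in the OpInfo.
    constraints = opinfo.get_constraints(device_type)
    if not constraints.get("supports_zero_dim", True):
        # Keep only samples whose primary input has at least one dimension.
        samples = [s for s in samples if s.input.dim() > 0]
    return samples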


Introduce New Class: DeviceCapability

A sample class for illustration; the full version should cover all relevant capabilities (e.g., dynamic shapes, memory layout):

class DeviceCapability:
    def __init__(self, device_type: str):
        self.device_type = device_type
        self.op_dtype_support = {}      # op_name -> set of dtypes
        self.op_constraints = {}        # op_name -> { constraint_key: value }

    def register_op_dtypes(self, op_name: str, *dtypes):
        self.op_dtype_support[op_name] = set(dtypes)

    def get_op_dtypes(self, op_name: str):
        return self.op_dtype_support.get(op_name, set())

    def add_op_constraint(self, op_name: str, key: str, value):
        if op_name not in self.op_constraints:
            self.op_constraints[op_name] = {}
        self.op_constraints[op_name][key] = value

    def get_op_constraints(self, op_name: str):
        return self.op_constraints.get(op_name, {})

Devices will register themselves so that OpInfo can reference the capability:

DEVICE_REGISTRY = {}

def register_device(device_type: str) -> DeviceCapability:
    cap = DeviceCapability(device_type)
    DEVICE_REGISTRY[device_type] = cap
    return cap

def get_device_capability(device_type: str) -> DeviceCapability:
    return DEVICE_REGISTRY.get(device_type, DeviceCapability(device_type))
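For example, an out-of-tree backend such as HPU could populate the registry in one place instead of touching every OpInfo. This is a minimal sketch assuming the registry code above is in scope; the dtype and constraint values are illustrative, not actual HPU support:

# Illustrative only: declare a backend's per-op support centrally.
hpu_cap = register_device("hpu")
hpu_cap.register_op_dtypes("addmm", torch.float32, torch.bfloat16)
hpu_cap.add_op_constraint("addmm", "supports_zero_dim", False)

# Queries then go through the registry rather than per-OpInfo fields.
assert torch.float32 in get_device_capability("hpu").get_op_dtypes("addmm")
assert get_device_capability("hpu").get_op_constraints("addmm") == {"supports_zero_dim": False}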

Integration with OpInfo

class GeneralizedOpInfo(OpInfo):
    def get_supported_dtypes(self, device_type):
        return get_device_capability(device_type).get_op_dtypes(self.name)

    def get_constraints(self, device_type):
        return get_device_capability(device_type).get_op_constraints(self.name)
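A test template could then select dtypes per device from the registry rather than from per-backend OpInfo fields. A hedged sketch; the helper name and the fallback policy are assumptions, not part of the proposal:

def dtypes_to_test(opinfo: GeneralizedOpInfo, device_type: str):
    # Prefer the dtypes the backend declared for this op in the registry.
    declared = opinfo.get_supported_dtypes(device_type)
    # Fall back to the OpInfo's own dtype list if the backend declared nothing.
    return declared if declared else set(opinfo.dtypes)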

Rationale and conclusion

The current approach has several limitations:
- It does not scale beyond a handful of backends
- It is hard-coded and duplicated in each OpInfo
- It mixes backend quirks with operator metadata

This proposal introduces a structured, centralized mechanism to model device-specific operator support in PyTorch. It improves maintainability, supports backend extensibility, and encourages a more declarative and introspectable test infrastructure.

Alternatives

Alternatives considered:

- Dynamic test filtering: harder to trace and debug.
- JSON/YAML device matrix: brittle and disconnected from test code.

Additional context

This builds upon the RFC introduced for device abstraction in the PyTorch frontend: pytorch/rfcs#66

cc @ptrblck @msaroufim @eqy @jerryzh168 @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @NmomoN @mengpenghui @fwenguang @cdzhan @1274085042 @PHLens @albanD @gujinghui @EikanWang @fengyuan14 @guangyey
