[ROCm] Incorrect number of arguments passed to kernel #140800


Closed

jataylo opened this issue Nov 15, 2024 · 5 comments
Labels
module: rocm (AMD GPU support for PyTorch), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

jataylo (Collaborator) commented Nov 15, 2024

🐛 Describe the bug

When using ROCm-specific tuning parameters for custom Triton kernels under torch.compile, we run into an exception.

This seems related to PR #137236, but simply adding `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` to `SPECIAL_CONFIG_NAMES` does not seem to be enough to fix it, as `kwargs` is empty at that point.

Reproducer:

import torch
import triton
from triton import language as tl

@triton.autotune(
    configs=[
        # matrix_instr_nonkdim is an AMD-specific tuning knob supplied via
        # triton.Config kwargs; it is not a parameter of the kernel itself.
        triton.Config({"BLOCK_SIZE": 4, "matrix_instr_nonkdim": 32}, num_stages=3, num_warps=8),
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    output = x + y
    tl.store(out_ptr + offsets, output, mask=mask)

@torch.compile(fullgraph=True)
def add_fn(x, y):
    output = torch.zeros_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel_autotuned[grid](x, y, output, n_elements)
    return output

x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")

Traceback:

W1115 11:37:10.193439 313 site-packages/torch/_higher_order_ops/triton_kernel_wrap.py:595] [0/0] Encountered an exception in identify_mutated_tensors, assuming every input is mutated
W1115 11:37:10.193439 313 site-packages/torch/_higher_order_ops/triton_kernel_wrap.py:595] [0/0] Traceback (most recent call last):
W1115 11:37:10.193439 313 site-packages/torch/_higher_order_ops/triton_kernel_wrap.py:595] [0/0]   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 574, in identify_mutated_tensors
W1115 11:37:10.193439 313 site-packages/torch/_higher_order_ops/triton_kernel_wrap.py:595] [0/0]     ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
W1115 11:37:10.193439 313 site-packages/torch/_higher_order_ops/triton_kernel_wrap.py:595] [0/0]   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 202, in generate_ttir
W1115 11:37:10.193439 313 site-packages/torch/_higher_order_ops/triton_kernel_wrap.py:595] [0/0]     raise ValueError("Incorrect number of arguments passed to kernel")
W1115 11:37:10.193439 313 site-packages/torch/_higher_order_ops/triton_kernel_wrap.py:595] [0/0] ValueError: Incorrect number of arguments passed to kernel
{'in_ptr0': FakeTensor(..., device='cuda:0', size=(4,)), 'in_ptr1': FakeTensor(..., device='cuda:0', size=(4,)), 'out_ptr': FakeTensor(..., device='cuda:0', size=(4,)), 'n_elements': 4, 'BLOCK_SIZE': 4, 'matrix_instr_nonkdim': 32}
['in_ptr0', 'in_ptr1', 'out_ptr', 'n_elements', 'BLOCK_SIZE']

Versions

Tip of tree

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @hongxiayang @naromero77amd

pytorch-bot bot added the module: rocm (AMD GPU support for PyTorch) label Nov 15, 2024
jataylo (Collaborator, Author) commented Nov 15, 2024

cc'ing @aakhundov in case anything jumps to mind on this one. Is this condition https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/triton_kernel_wrap.py#L199 even valid? If we support supplying implicit args via kwargs, then we shouldn't expect `kwargs` and `kernel.arg_names` to match.
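For reference, the condition at that line (quoted again further down in this thread):

if len(kwargs) != len(kernel.arg_names):
    raise ValueError("Incorrect number of arguments passed to kernel")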

cpuhrsch added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Nov 16, 2024
aakhundov (Contributor) commented

Thanks for flagging, @jataylo, I'll take a look. Is the intended use case to pass the AMD-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` as part of `triton.Config`'s kwargs dict (first argument) without having them in the kernel parameters?

jataylo (Collaborator, Author) commented Nov 18, 2024

@aakhundov Correct, we pass the AMD-specific args via kwargs, either as implemented in this reproducer or through the kwargs of the kernel launch, e.g.:

extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
_fwd_kernel[grid](
    q_extend,
    k_extend,
    ...,
    ...,
    **extra_kargs,
)
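For reference, a self-contained version of that second case (a minimal sketch based on the reproducer above; `add_kernel` here is hypothetical, just the autotuned kernel without its decorator, so the AMD-specific options are supplied at the launch site instead):

import torch
import triton
from triton import language as tl

@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

@torch.compile(fullgraph=True)
def add_fn(x, y):
    output = torch.zeros_like(x)
    n_elements = output.numel()
    grid = (triton.cdiv(n_elements, 4),)
    # AMD-specific options passed as extra launch kwargs; they are not
    # parameters of add_kernel, which trips the same length check.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4,
                     waves_per_eu=4, matrix_instr_nonkdim=16, kpack=2)
    return output

x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)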

Both cases run into this issue. I couldn't see a solid way to resolve this in the code without challenging this assumption:

if len(kwargs) != len(kernel.arg_names):
    raise ValueError("Incorrect number of arguments passed to kernel")
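For illustration, one way to relax that assumption (a hedged sketch with hypothetical names, not the actual fix that landed) is to filter out known backend-specific launch options before the comparison, the same way native Triton tolerates such excess kwargs:

# Illustrative sketch only; the set and helper name are hypothetical.
# AMD backend launch options may appear in kwargs without being
# parameters of the kernel.
AMD_LAUNCH_OPTIONS = {"matrix_instr_nonkdim", "waves_per_eu", "kpack"}

def check_kernel_args(kernel, kwargs):
    # Drop backend-only launch options before comparing against the
    # kernel's declared parameter list.
    filtered = {k: v for k, v in kwargs.items() if k not in AMD_LAUNCH_OPTIONS}
    if len(filtered) != len(kernel.arg_names):
        raise ValueError("Incorrect number of arguments passed to kernel")
    return filtered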

aakhundov added a commit that referenced this issue Nov 19, 2024

Fixes #140800.

On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either directly to the kernel or via `triton.Config`, whereas they don't exist as kernel parameters. Native Triton code handles those excess args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out.

ghstack-source-id: 084f3e3
Pull Request resolved: #141062
aakhundov (Contributor) commented

@jataylo #141062 should fix the error showing up here.

aakhundov added a commit that referenced this issue Nov 19, 2024

ghstack-source-id: 70037b8
Pull Request resolved: #141062
cyyever pushed a commit to cyyever/pytorch that referenced this issue Nov 20, 2024 (pytorch#141062, approved by https://github.com/oulgen)
jataylo (Collaborator, Author) commented Nov 20, 2024

Perfect, thank you for investigating this @aakhundov!

youssef62 pushed a commit to youssef62/pytorch that referenced this issue Nov 23, 2024 (pytorch#141062)
Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this issue Dec 2, 2024 (pytorch#141062)
jataylo pushed a commit to ROCm/pytorch that referenced this issue Dec 4, 2024 (pytorch#141062, cherry picked from commit b740a1b)
pobin6 pushed a commit to pobin6/pytorch that referenced this issue Dec 5, 2024 (pytorch#141062)
pruthvistony pushed a commit to ROCm/pytorch that referenced this issue Jan 31, 2025 (#1768, cherry picked from commit b740a1b; co-authored by Adnan Akhundov <aakhundov@fb.com>)
Labels
module: rocm, triaged

Projects
Status: Done

Development
No branches or pull requests

3 participants