[ROCm] Incorrect number of arguments passed to kernel #140800
Comments
CCing @aakhundov in case anything jumps to mind on this one. Is this condition https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/triton_kernel_wrap.py#L199 even valid? If we support supplying implicit args via kwargs, then we shouldn't expect kwargs and kernel.arg_names to match.
Thanks for flagging @jataylo, I'll take a look. Is it the intended use case to pass the AMD-specific args like |
@aakhundov Correct, we pass the AMD-specific args via kwargs, either as implemented in this reproducer or through kwargs of the kernel args e.g.
Both cases run into this issue, I couldn't see a solid way to resolve this in the code without challenging this assumption:
|
Fixes #140800. On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either directly to the kernel or via `triton.Config`, whereas they don't exist as kernel parameters. Native Triton code handles those excessive args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out. ghstack-source-id: 084f3e3 Pull Request resolved: #141062
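To make the idea in the commit message concrete, here is a minimal sketch of ignoring kwargs that do not correspond to declared kernel parameters. It is not the actual PyTorch or Triton code: the helper `drop_non_parameter_kwargs`, the example parameter names, and the values are all hypothetical, and the real handling in `triton_kernel_wrap.py` may look quite different.

```python
# Hypothetical illustration only; this is NOT the PyTorch/Triton implementation.
def drop_non_parameter_kwargs(arg_names, kwargs):
    """Split kwargs into those naming a declared kernel parameter and the rest."""
    declared = set(arg_names)
    # Keep only entries whose key is a declared kernel parameter.
    kept = {name: value for name, value in kwargs.items() if name in declared}
    # Everything else (e.g. backend-specific tuning knobs) is set aside.
    dropped = sorted(name for name in kwargs if name not in declared)
    return kept, dropped


# Assumed kernel signature for the sake of the example.
arg_names = ["in_ptr", "out_ptr", "n_elements", "BLOCK_SIZE"]

# Call-site kwargs that mix real parameters with AMD-specific args.
kwargs = {
    "in_ptr": "x",
    "out_ptr": "y",
    "n_elements": 1024,
    "BLOCK_SIZE": 256,
    "waves_per_eu": 2,
    "matrix_instr_nonkdim": 16,
    "kpack": 1,
}

kept, dropped = drop_non_parameter_kwargs(arg_names, kwargs)
print(list(kept))  # names that match kernel parameters
print(dropped)     # backend-specific names that were set aside
```

With a filter like this applied before any analysis that compares kwargs against `kernel.arg_names`, the extra AMD-specific entries would no longer cause a mismatch.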
…rch#141062) Fixes pytorch#140800. On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either directly to the kernel or via `triton.Config`, whereas they don't exist as kernel parameters. Native Triton code handles those excessive args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out. Pull Request resolved: pytorch#141062 Approved by: https://github.com/oulgen
Perfect, thank you for investigating this @aakhundov!
…rch#141062) Fixes pytorch#140800. On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either directly to the kernel or via `triton.Config`, whereas they don't exist as kernel parameters. Native Triton code handles those excessive args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out. Pull Request resolved: pytorch#141062 Approved by: https://github.com/oulgen (cherry picked from commit b740a1b)
… analysis (#141… (#1768) Fixes pytorch#140800. On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either directly to the kernel or via `triton.Config`, whereas they don't exist as kernel parameters. Native Triton code handles those excessive args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out. Pull Request resolved: pytorch#141062 Approved by: https://github.com/oulgen (cherry picked from commit b740a1b) Fixes #ISSUE_NUMBER Co-authored-by: Adnan Akhundov <aakhundov@fb.com>
🐛 Describe the bug
When using ROCm-specific tuning parameters for custom Triton kernels, we run into an exception.
It seems related to PR #137236, but simply adding `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` to `SPECIAL_CONFIG_NAMES` does not seem to be enough to fix this, as `kwargs` is empty at this point.
Traceback:
Versions
Tip of tree
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @hongxiayang @naromero77amd