🚀 The feature, motivation and pitch
There is GEMM logging in Inductor, but it doesn't work well for torch.bmm right now.
Repro:

```python
import os
os.environ["TORCH_LOGS"] = "inductor"

import torch

M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)
compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
print("done")
```
Logs:

```
torch/_inductor/kernel/bmm.py:182] [0/0] Tuned aten.bmm: m=1024, n=1024, k=1024, mat1_dtype=torch.bfloat16, mat2_dtype=torch.bfloat16, output_layout=FixedLayout('cuda:0', torch.bfloat16, size=[10, 1024, 1024], stride=[1048576, 1024, 1])
torch/_inductor/compile_fx.py:998] [0/0] Overview info of inductor aten mms:
torch/_inductor/compile_fx.py:999] [0/0] Name | M | N | K | Count
torch/_inductor/compile_fx.py:1004] [0/0] ----------------------------------------------------------------------------------------------------
torch/_inductor/compile_fx.py:1006] [0/0] aten.bmm | 1024 | 1024 | 1024 | 1
torch/_inductor/compile_fx.py:1007] [0/0] ----------------------------------------------------------------------------------------------------
```
Neither the tuning log line nor the overview table reflects the batch size B (10 in the repro above). We should fix that.
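For illustration only, one possible shape of the fix is to thread the batch dimension into the message formatting. The helper below is a hypothetical sketch, not PyTorch's actual logging code; the function name and signature are assumptions:

```python
# Hypothetical sketch of a batch-aware "Tuned ..." log message.
# This is NOT the actual Inductor implementation; the helper name,
# signature, and output format are illustrative assumptions.

def format_tuned_mm_log(op: str, m: int, n: int, k: int, batch=None) -> str:
    """Build the tuning log message, adding B=<batch> for batched matmuls."""
    dims = f"m={m}, n={n}, k={k}"
    if batch is not None:
        # Prepend the batch size so aten.bmm logs show all four dims.
        dims = f"B={batch}, " + dims
    return f"Tuned {op}: {dims}"

# For the repro above (A: [10, 1024, 1024], B: [10, 1024, 1024]):
print(format_tuned_mm_log("aten.bmm", 1024, 1024, 1024, batch=10))
# prints: Tuned aten.bmm: B=10, m=1024, n=1024, k=1024
```

The overview table in `compile_fx.py` would need a matching `B` column (or `-` for non-batched ops such as aten.mm) so that aten.bmm entries with different batch sizes are not collapsed into one row.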
Alternatives
No response
Additional context
No response
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov