Description
Minimal repro:
import torch
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization import MappingType
x1 = torch.randn(53, 2048)
w1 = torch.randn(53, 2048)
torch.matmul(x1, w1.t()) # OK
x2 = torch.randn(53, 2048)
w2 = torch.randn(53, 2048)
w2 = to_affine_quantized_intx(
    w2,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, 32),
    target_dtype=torch.int8,
    quant_min=-8,
    quant_max=7,
    eps=torch.finfo(torch.float32).eps,
)
torch.matmul(x2, w2.t()) # Fails
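For reference, the dense path in the working case produces the expected (53, 53) output, since w1.t() has shape (2048, 53):

torch.matmul(x1, w1.t()).shape  # torch.Size([53, 53])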
Error:
Traceback (most recent call last):
  File "/home/andrewor/local/ao/torchao/dtypes/affine_quantized_tensor_ops.py", line 356, in _
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
  File "/home/andrewor/local/ao/torchao/dtypes/affine_quantized_tensor_ops.py", line 159, in _quantized_linear_op
    raise QuantizedLinearNotImplementedError(
torchao.dtypes.affine_quantized_tensor_ops.QuantizedLinearNotImplementedError: No specialized dispatch found for quantized linear op

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/andrewor/local/ao/torchao/utils.py", line 425, in _dispatch__torch_function__
    return func(*args, **kwargs)
  File "/home/andrewor/local/ao/torchao/utils.py", line 440, in _dispatch__torch_dispatch__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
  File "/home/andrewor/local/ao/torchao/utils.py", line 401, in wrapper
    return func(f, types, args, kwargs)
  File "/home/andrewor/local/ao/torchao/dtypes/affine_quantized_tensor_ops.py", line 370, in _
    return func(input_tensor, weight_tensor)
  File "/home/andrewor/local/pytorch/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (53x2048 and 53x2048)
Seems like regular torch.matmul does an implicit transpose in this case, but not when one of the args is an AQT. May need to investigate the broadcasting rules for torch.matmul further and replicate them for AQT: https://pytorch.org/docs/stable/generated/torch.matmul.html
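In the meantime, a possible workaround (a sketch, not a confirmed fix; it assumes AffineQuantizedTensor exposes dequantize() and that the F.linear path falls back to dequantization when no specialized kernel matches) is to avoid the matmul dispatch path:

import torch.nn.functional as F

# Option 1: F.linear computes x2 @ w2.t(), and is the dispatch path
# AffineQuantizedTensor weights are primarily exercised with.
out = F.linear(x2, w2)

# Option 2: dequantize explicitly and use the regular dense matmul.
out = torch.matmul(x2, w2.dequantize().t())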
Reported by @telgamal-1