AffineQuantizedTensor matmul shape issue #2069
Closed
@andrewor14

Description

Minimal repro:

import torch
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization import MappingType

x1 = torch.randn(53, 2048)
w1 = torch.randn(53, 2048)
torch.matmul(x1, w1.t())  # OK: plain float tensors, (53, 2048) @ (2048, 53)

x2 = torch.randn(53, 2048)
w2 = torch.randn(53, 2048)
# Quantize w2 to the 4-bit range [-8, 7] stored as int8, with symmetric
# per-group quantization (group size 32 along the last dim)
w2 = to_affine_quantized_intx(
    w2,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, 32),
    target_dtype=torch.int8,
    quant_min=-8,
    quant_max=7,
    eps=torch.finfo(torch.float32).eps,
)
torch.matmul(x2, w2.t())  # Fails

Error:

Traceback (most recent call last):
  File "/home/andrewor/local/ao/torchao/dtypes/affine_quantized_tensor_ops.py", line 356, in _
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
  File "/home/andrewor/local/ao/torchao/dtypes/affine_quantized_tensor_ops.py", line 159, in _quantized_linear_op
    raise QuantizedLinearNotImplementedError(
torchao.dtypes.affine_quantized_tensor_ops.QuantizedLinearNotImplementedError: No specialized dispatch found for quantized linear op

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/andrewor/local/ao/torchao/utils.py", line 425, in _dispatch__torch_function__
    return func(*args, **kwargs)
  File "/home/andrewor/local/ao/torchao/utils.py", line 440, in _dispatch__torch_dispatch__
    return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
  File "/home/andrewor/local/ao/torchao/utils.py", line 401, in wrapper
    return func(f, types, args, kwargs)
  File "/home/andrewor/local/ao/torchao/dtypes/affine_quantized_tensor_ops.py", line 370, in _
    return func(input_tensor, weight_tensor)
  File "/home/andrewor/local/pytorch/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (53x2048 and 53x2048)
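
For context, the traceback points at a two-stage failure: the handler in affine_quantized_tensor_ops.py first tries the specialized quantized linear path, and when that raises QuantizedLinearNotImplementedError it falls back to re-invoking the aten op with the original (untransposed) arguments. A paraphrased sketch of that suspected control flow (not the actual torchao source):

def _aqt_matmul_handler(func, input_tensor, weight_tensor, bias=None):
    try:
        # Specialized path (affine_quantized_tensor_ops.py:356 in the trace)
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except QuantizedLinearNotImplementedError:
        # Fallback (affine_quantized_tensor_ops.py:370 in the trace): the
        # original aten op is called with both tensors still (53, 2048),
        # i.e. the transpose from matmul(x2, w2.t()) has been dropped
        return func(input_tensor, weight_tensor)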

It seems the explicit transpose is applied as expected when both args are regular tensors, but gets dropped when one of them is an AQT: the final error shows mat2 still has shape 53x2048. May need to investigate the shape and broadcasting rules for torch.matmul further and replicate them for AQT: https://pytorch.org/docs/stable/generated/torch.matmul.html
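
In the meantime, a possible interim workaround (a sketch, assuming the cost of dequantizing up front is acceptable) is to dequantize the AQT back to a plain tensor before the matmul, so the explicit transpose applies normally:

# Hypothetical workaround: dequantize first so w2 is a regular float
# tensor and the transpose is applied before the matmul dispatch
out = torch.matmul(x2, w2.dequantize().t())  # (53, 2048) @ (2048, 53) -> (53, 53)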

Reported by @telgamal-1
