torchao serialization · Issue #28 · mobiusml/gemlite

torchao serialization #28

Open
jerryzh168 opened this issue Apr 16, 2025 · 1 comment
@jerryzh168

There seem to be some problems when trying to serialize a gemlite-quantized model in torchao. I got:

_pickle.PicklingError: Can't pickle <class 'gemlite.triton_kernels.gemv_A16fWnO16f_int32packing.gemv_A16fWnO16f'>: it's not the same object as gemlite.triton_kernels.gemv_A16fWnO16f_int32packing.gemv_A16fWnO16f

using: https://gist.github.com/jerryzh168/85b4afa959e37f3a84236aeedef1df1a

@mobicham
Collaborator

Hi Jerry, thanks for reporting the issue. It should be fixed now on the master branch: 0bedbe7, 9f55e2e

State dict test:

import torch
from gemlite.helper import A16W8

in_features, dtype, device = 4096, torch.float16, 'cuda:0'

def patch_linearlayers(model, fct):
    # Recursively replace every torch.nn.Linear in the model using fct
    for name, layer in model.named_children():
        if isinstance(layer, torch.nn.Linear):
            setattr(model, name, fct(layer, name))
        else:
            patch_linearlayers(layer, fct)

def patch_linear_to_gemlite(layer, name):
    # Skip layers whose feature dims are not divisible by 64
    if min(layer.in_features, layer.out_features) % 64 != 0:
        return layer.to(device)
    return A16W8(device=device).from_linear(layer)

def _quantize(model):
    patch_linearlayers(model, patch_linear_to_gemlite)
    return model

def create_model(in_features):
    model = torch.nn.Sequential(
        torch.nn.Linear(in_features, in_features),
        torch.nn.ReLU(),
        torch.nn.Linear(in_features, in_features),
        torch.nn.ReLU(),
        torch.nn.Linear(in_features, in_features), 
        torch.nn.ReLU(),
        torch.nn.Linear(in_features, in_features)  
    )

    return model.to(dtype=dtype, device=device)


torch.manual_seed(0)
model = create_model(in_features)
x = torch.randn((1, in_features), dtype=dtype, device=device)/10.
y_ref = model(x)

#Check difference between quantized and unquantized
_quantize(model)
y_pred = model(x)
assert (y_ref - y_pred).abs().mean() <= 1e-3, 'Values mismatch between quantized and unquantized'

#Model state dict
torch.save(model.state_dict(), 'quant_model')
state_dict = torch.load('quant_model', map_location=device, weights_only=True)

#Load the state_dict and check the difference between the original quantized model and the loaded one
torch.manual_seed(100)
model2 = _quantize(create_model(in_features))
model2.load_state_dict(state_dict)
y_pred2 = model2(x)
assert (y_pred2 - y_pred).abs().mean() <= 1e-5, 'Values mismatch after state_dict loading'
