How to Insert High-Dimensional Matrix data without Protobuf Read Errors · Issue #6737 · onnx/onnx · GitHub

How to Insert High-Dimensional Matrix data without Protobuf Read Errors #6737


Open

zhudianGG opened this issue Feb 28, 2025 · 11 comments
Labels
question Questions about ONNX

Comments

@zhudianGG

Bug Report

Is the issue related to model conversion?

No, this issue is not related to model conversion. It occurs during the process of modifying an existing ONNX model by inserting nodes.

Describe the bug

I encountered an issue when trying to insert multiple nodes with large weight matrices into the ONNX model of a large language model. Specifically, when inserting nodes with matrices of dimension [11008, 11008] in fp32 format before the down_proj MatMul nodes of the LLaMA 2 7B model, I can only do so in up to 4 layers. Smaller matrices, such as [4096, 4096], can be inserted in all layers without issue. Beyond that point, Protobuf fails to read the model, which suggests a limitation in how ONNX handles large embedded tensors.

error:
in deserialize_proto
decoded = typing.cast(Optional[int], proto.ParseFromString(serialized))
google.protobuf.message.DecodeError: Error parsing message with type 'onnx.ModelProto'

insert function:

def insert_matmul_op(self, target_node_name, weight_matrix):
    # Find the target node
    target_node_index = None
    for i, node in enumerate(self.graph.node):
        if node.name == target_node_name:
            target_node_index = i
            break

    if target_node_index is None:
        raise ValueError(f"Node {target_node_name} not found in the model.")

    target_node = self.graph.node[target_node_index]

    # Build a new MatMul node and insert it
    new_node_name = target_node.name + str(self.node_count) + "_pre_matmul"
    self.node_count += 1
    node_tensor_name = new_node_name + ".weight"
    out_name = new_node_name + "/output"

    # Create a tensor for the weight matrix; fp16 payloads are passed to
    # make_tensor as their raw uint16 bit patterns
    node_tensor_proto = helper.make_tensor(
        name=node_tensor_name,
        data_type=TensorProto.FLOAT16,
        dims=weight_matrix.shape,
        # vals=weight_matrix.flatten().tolist()  # fp32 variant
        vals=weight_matrix.flatten().view(np.uint16).tolist()
    )

    # Print the shape of the weight matrix
    print(f"Weight matrix shape: {weight_matrix.shape}")

    # Create the new MatMul node
    new_node = helper.make_node(
        "MatMul",
        inputs=[target_node.input[0], node_tensor_name],
        outputs=[out_name],
        name=new_node_name
    )

    # Insert the new node immediately before the target node
    self.graph.node.insert(target_node_index, new_node)

    # Add the initializer for the new node
    self.graph.initializer.append(node_tensor_proto)

    # Rewire target_node to take the output of the new node as input
    target_node.input[0] = out_name
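
For context, a hypothetical call site for this method (the wrapper class name, model path, and node name below are illustrative, matching the fp16 branch of the code):

import numpy as np

# Illustrative only: ONNXModelModifier is assumed to be the class that owns
# insert_matmul_op, self.graph, and self.node_count
modifier = ONNXModelModifier("llama2_7b.onnx")
weights = np.random.rand(11008, 11008).astype(np.float16)
modifier.insert_matmul_op("/model/layers.0/mlp/down_proj/MatMul", weights)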

System information

Expected behavior

I expected to be able to insert nodes with high-dimensional matrices into all layers of the model without encountering Protobuf read errors, similar to the behavior observed with smaller matrices.

Notes

Changing the format to fp16 allows insertion into 5-6 layers.
The maximum number of insertable nodes is not affected by the order of the layers.
Attempting to replace a large matrix with multiple smaller matrices still results in a limitation of inserting only up to 4 nodes in the same layer.
Any insights or suggestions on how to address this issue would be greatly appreciated.

@zhudianGG zhudianGG added the bug label Feb 28, 2025
@justinchuby justinchuby added question Questions about ONNX and removed bug labels Feb 28, 2025
@justinchuby justinchuby changed the title [Bug]Limitation on Inserting High-Dimensional Matrix Nodes in ONNX Models Causes Protobuf Read Errors How to Insert High-Dimensional Matrix data without Protobuf Read Errors Feb 28, 2025
@justinchuby
Contributor

A protobuf message cannot exceed the size of 2GB. You can leverage the external tensor option to store the data outside of the protobuf file:

save_as_external_data: bool = False,
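
A quick size check makes the 2GB explanation plausible: one [11008, 11008] fp32 tensor stored inline is about 0.45 GiB, so four of them already approach the cap. A minimal diagnostic sketch (the model path is illustrative; ByteSize() is the standard protobuf message method):

import onnx

# One [11008, 11008] fp32 initializer stored inline in the proto:
per_tensor_bytes = 11008 * 11008 * 4   # ~0.45 GiB, so four of them ~1.8 GiB
print(f"One inserted tensor: {per_tensor_bytes / 2**30:.2f} GiB")

# Check the serialized size of the model proto against the 2 GiB cap
model = onnx.load("model.onnx", load_external_data=False)
print(f"Model proto: {model.ByteSize() / 2**30:.2f} GiB (cap: 2 GiB)")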

@zhudianGG
Author

Hi @justinchuby,

I have already tried saving the model with:

def save_model(self, output_path):
    onnx.save(self.onnx_model, output_path, save_as_external_data=True, all_tensors_to_one_file=True, location="model_data", size_threshold=1024)
    print(f"Modified model saved to {output_path}")

but I still get the same error.

@justinchuby
Contributor

When you load, did you set load_external_data to False?

@justinchuby
Contributor

You can also look into using onnxscript.ir for this.

@zhudianGG
Author
zhudianGG commented Feb 28, 2025

Hi @justinchuby,

Yes, I tried this in __init__:
self.onnx_model = onnx.load(model_path, load_external_data=False)
load_external_data_for_model(self.onnx_model, "model_data_file")

Unfortunately it didn't work, so it seems I should try the onnxscript.ir approach you mentioned, as in the example at https://github.com/microsoft/onnxscript/tree/main/examples/pattern_rewriting.py:

model = get_rotary_model(True)
ir_model = ir.serde.deserialize_model(model)

rule.apply_to_model(ir_model)
rewritten_model = ir.serde.serialize_model(ir_model)

Or did I misunderstand?

@zhudianGG
Author

Hi @justinchuby,

Sorry to bother you, but could you please provide a possible example of using onnxscript.ir to insert a customized MatMul node? I would really appreciate it~

@justinchuby
Contributor
justinchuby commented Mar 5, 2025
from onnxscript import ir
import numpy as np

def insert_matmul_op(model: ir.Model, target_node_name: str, weight_matrix: np.ndarray):
    # A -> Target
    # {A, initializer} -> MatMul -> Target
    # Find the target node
    target_node = model.graph.node(target_node_name)
    tensor = ir.tensor(weight_matrix, name=f"{target_node.inputs[0].name}_initializer")
    initializer = ir.Input(tensor.name, tensor.shape, ir.TensorType(tensor.dtype))
    initializer.const_value = tensor
    model.graph.register_initializer(initializer)
    # Print the shape of the weight matrix
    print(f"Weight matrix shape: {weight_matrix.shape}")
    # Create the new MatMul node
    new_node = ir.Node("", "MatMul", inputs=[target_node.inputs[0], initializer])
    new_node.outputs[0].name = f"{target_node.inputs[0].name}_mul"
    target_node.prepend(new_node)
    target_node.replace_input_with(0, new_node.outputs[0])
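
A hypothetical driver for this snippet (the path and node name are illustrative, and it assumes the installed onnxscript version exposes the ir.load helper):

import numpy as np
from onnxscript import ir

model = ir.load("model.onnx")  # illustrative path
weights = np.random.rand(11008, 11008).astype(np.float32)
insert_matmul_op(model, "/model/layers.0/mlp/down_proj/MatMul", weights)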

@zhudianGG
Author
zhudianGG commented Mar 7, 2025

Hi @justinchuby,

It seems the code didn't set the value's const_value. I get:

File ".../lib/python3.10/site-packages/onnxscript/ir/_core.py", line 1973, in register_initializer
    raise ValueError(
ValueError: Value 'Value('/model/layers.0/post_attention_layernorm/Mul_1_output_0_initializer_1', type=Tensor(1), shape=[4096,5504], producer=None, index=None)' must have its const_value set to be an initializer.

I tried implementing it with the following code, but it ran into a similar outcome:

import onnx
from onnx import helper, TensorProto
import numpy as np
from onnx.external_data_helper import load_external_data_for_model
from onnxscript import ir

class ONNXModelModifier:
    def __init__(self, model_path):
        self.onnx_model = onnx.load(model_path, load_external_data=False)
        load_external_data_for_model(self.onnx_model, "model_data_file")
        self.graph = self.onnx_model.graph
        self.node_count = 1
        self.ir_model = ir.serde.deserialize_model(self.onnx_model)

    def insert_matmul_op_ir(self, target_node_name, weight_matrix):
        # Find the target node
        target_node = None
        for node in self.graph.node:
            if node.name == target_node_name:
                target_node = node
                break

        if target_node is None:
            raise ValueError(f"Node {target_node_name} not found in the model.")

        # Create a unique name for the new initializer
        initializer_name = f"{target_node.input[0]}_initializer_{self.node_count}"

        # Create the initializer for the weight matrix
        tensor = ir.tensor(weight_matrix, name=initializer_name)
        initializer = ir.Input(initializer_name, tensor.shape, ir.TensorType(TensorProto.FLOAT))
        self.ir_model.graph.register_initializer(initializer)
        #self.graph.initializer(initializer)

        # Create the new MatMul node
        new_node_name = f"{target_node.name}_pre_matmul_{self.node_count}"
        new_node = ir.Node("", "MatMul", inputs=[target_node.input[0], initializer_name], outputs=[new_node_name])
        self.graph.node.insert(self.graph.node.index(target_node), new_node)

        # Update the target node's input to use the output of the new MatMul node
        target_node.input[0] = new_node.output[0]

        # Increment the node count
        self.node_count += 1

    def save_model(self, output_path):
        self.onnx_model = ir.serde.serialize_model(self.ir_model)
        onnx.save(self.onnx_model, output_path, save_as_external_data=True, all_tensors_to_one_file=True, location="model_data", size_threshold=1024)
        print(f"Modified model saved to {output_path}")

model_modifier = ONNXModelModifier('model.onnx')
weight_matrix0 = np.random.rand(4096, 5504).astype(np.float32)  # Example weight matrix
weight_matrix1 = np.random.rand(5504, 5504).astype(np.float32)  # Example weight matrix
weight_matrix2 = np.random.rand(5504, 4096).astype(np.float32)  # Example weight matrix

Could you help me with this situation, or do you have other advice for me? I would really appreciate it.

@justinchuby
Contributor
justinchuby commented Mar 7, 2025

Updated code. Please check: I missed the line that assigns the tensor, initializer.const_value = tensor. Also, the way you did self.graph.node.insert is not part of the IR APIs; you may refer to the snippet above for the proper usage: target_node.prepend(new_node).
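
In other words, the two corrections, using the names from the snippets above:

# 1. Attach the tensor data before registering the initializer
initializer.const_value = tensor
model.graph.register_initializer(initializer)

# 2. Insert through the IR's node API instead of list.insert, then rewire
target_node.prepend(new_node)
target_node.replace_input_with(0, new_node.outputs[0])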

@zhudianGG
Author

Hi @justinchuby, thanks for your example, but it seems to result in the same outcome. I will provide the screenshot and error later. I am trying to insert nodes into smaller models like Phi to test and finish my task; if you have better solutions or other advice, feel free to share~

@justinchuby
Contributor

You may leverage https://github.com/microsoft/onnxscript/blob/main/onnxscript/ir/external_data.py when working with ONNX IR to externalize the big weights. If you save the model as is, there's likely going to be a problem.
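
A minimal end-to-end sketch of that flow, assuming the installed onnxscript version exposes ir.load/ir.save and that ir.save accepts an external_data argument (paths are illustrative):

from onnxscript import ir

model = ir.load("model.onnx")
# ... insert nodes as in the snippets above ...
# Write large tensors to a side file so the protobuf part stays under 2GB
ir.save(model, "model_edited.onnx", external_data="model_edited.onnx.data")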
