Request for Official HF → OLMoE Checkpoint Conversion Script · Issue #26 · allenai/OLMoE · GitHub

Open
aditi184 opened this issue Mar 2, 2025 · 10 comments

@aditi184 commented Mar 2, 2025

Hi OLMo team,

I'm currently working on converting a Hugging Face model (allenai/OLMoE-1B-7B-0924-Instruct) into OLMo/OLMoE's pretraining checkpoint format so that I can resume pretraining. While I was able to convert the model weights, the converted checkpoint is missing metadata.json, which seems to be a critical component for loading the checkpoint in restore_checkpoint().

After examining the metadata.json generated when training OLMo from scratch, I realized that it contains non-trivial fields and additional metadata about the architecture and checkpoint format. Reconstructing this file seems error-prone, and I couldn't find documentation or scripts for safely performing this conversion.

Would it be possible for you to provide an official Hugging Face → OLMo checkpoint conversion script, or any guidance on the exact format required for metadata.json? This would be extremely helpful, as I am working on a project that requires continual pretraining of the OLMoE-Instruct model from existing HF checkpoints using OLMo.

Alternatively, it would also be helpful if you could share an OLMoE checkpoint in the format accepted by the pretraining script (https://github.com/allenai/OLMo/blob/Muennighoff/MoE/scripts/train.py).

Thanks for your help!

@Muennighoff (Collaborator)

I realized that it contains non-trivial fields and additional metadata about the architecture and checkpoint format

Can you share which fields you are referring to? Maybe you could also share the script you already wrote? I'm not sure atm if such a script exists, but I'm tagging some people in case they know: cc @soldni @swj0419

@aditi184 (Author) commented Mar 5, 2025

Hey @Muennighoff, I meant loading the pretrained checkpoint in order to continue pretraining it further; the model needs to be in the OLMo class format, right?

I wrote the code to convert it to that class using the attached script.

Run the script using: python convert_hf_to_olmo.py --hf_model_path allenai/OLMoE-1B-7B-0924-Instruct --output_path /home/mila/k/khandela/scratch/ai2-llm/checkpoints/OLMoE/base --tokenizer_path allenai/OLMoE-1B-7B-0924


import os
import torch
import yaml
import argparse
from pathlib import Path
from transformers import OlmoeForCausalLM, AutoTokenizer

def convert_hf_to_olmo(hf_model_path, output_path, tokenizer_path=None):
    os.makedirs(output_path, exist_ok=True)
    print(f"Loading Hugging Face model from {hf_model_path}")
    model = OlmoeForCausalLM.from_pretrained(hf_model_path)
    state_dict = model.state_dict()
    
    olmo_state_dict = {}
    
    print("Reformatting weights to OLMo pretraining format...")
    for key, value in state_dict.items():
        # Name-based remapping from HF OLMoE parameter names to the OLMo naming scheme.
        # Note: this does not fuse q_proj/k_proj/v_proj into a single att_proj tensor.
        new_key = key.replace("model.layers", "transformer.blocks").replace(".self_attn.", ".att_proj.")
        new_key = new_key.replace("o_proj", "attn_out").replace("gate_proj", "ffn.router.layer")
        new_key = new_key.replace("up_proj", "ffn.experts.mlp.v1").replace("down_proj", "ffn.experts.mlp.w2")
        new_key = new_key.replace("input_layernorm", "attn_norm").replace("post_attention_layernorm", "ff_norm")

        olmo_state_dict[new_key] = value
    
    torch.save(olmo_state_dict, os.path.join(output_path, "model.pt"))
    print(f"Converted model saved to {output_path}/model.pt")
    
    config = {
        "model": {
            "n_layers": model.config.num_hidden_layers,
            "n_heads": model.config.num_attention_heads,
            "d_model": model.config.hidden_size,
            "max_sequence_length": model.config.max_position_embeddings,
            "embedding_size": model.config.vocab_size,
            "pad_token_id": model.config.pad_token_id,
            "eos_token_id": model.config.eos_token_id,
            "weight_tying": model.config.tie_word_embeddings
        }
    }
    with open(os.path.join(output_path, "config.yaml"), "w") as f:
        yaml.dump(config, f)
    print(f"Config file saved to {output_path}/config.yaml")
    
    if tokenizer_path:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        tokenizer.save_pretrained(output_path)
        print(f"Tokenizer saved to {output_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_model_path", required=True, help="Path to the Hugging Face model.")
    parser.add_argument("--output_path", required=True, help="Directory to save the converted OLMo model.")
    parser.add_argument("--tokenizer_path", default=None, help="Optional: Path to Hugging Face tokenizer.")
    args = parser.parse_args()
    
    convert_hf_to_olmo(args.hf_model_path, args.output_path, args.tokenizer_path)
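
A quick sanity check (a sketch; the path below is a placeholder) is to reload the saved model.pt and spot-check the renamed keys:

# Sketch: reload the converted checkpoint and list a few renamed keys and shapes.
import torch

sd = torch.load("/path/to/output_path/model.pt", map_location="cpu")
print(len(sd), "tensors")
for key in sorted(sd)[:8]:
    print(key, tuple(sd[key].shape))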

I wanted to know whether the above code is safe for this conversion, since I want to continue pretraining the allenai/OLMoE-1B-7B-0924-Instruct checkpoint from HF.

After converting this way and pointing the pretraining code at the model path, it expects a metadata.json file, which I am attaching:

metadata.json

I am not sure how to create this file.

@aditi184 (Author) commented Mar 5, 2025

Basically, I want to know how I can continue pretraining the allenai/OLMoE-1B-7B-0924-Instruct checkpoint from HF using the available OLMoE pretraining code.

@Muennighoff (Collaborator)

Looks pretty good! I think the metadata file specifies some shapes for FSDP. To recreate it, I think you can simply train a dummy model from scratch for 1 step with the FSDP distribution setup you want, save its checkpoint, and take its metadata.json.
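
For example, something like this might work to copy it over (a sketch; paths are placeholders and assume the dummy run wrote an unsharded checkpoint directory containing metadata.json):

# Sketch: copy metadata.json from a 1-step dummy run (saved with the same FSDP layout)
# into the directory holding the converted model.pt. Paths are placeholders.
import shutil
from pathlib import Path

dummy_step = Path("/path/to/dummy-run/step1")      # checkpoint written by scripts/train.py
converted = Path("/path/to/converted/OLMoE/base")  # output of the conversion script
shutil.copy2(dummy_step / "metadata.json", converted / "metadata.json")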

@aditi184 (Author) commented Mar 5, 2025

I actually realized that the script I pasted above is imperfect and leads to errors. I tried improving it but couldn't succeed in loading the converted model using:

from olmo.model import OLMo
import torch
model = OLMo.from_checkpoint("/home/mila/k/khandela/scratch/ai2-llm/checkpoints/OLMoE/test")

Now I am a bit confused about a few things: for pretraining OLMoE, the model class still remains OLMo? Are there any new keys introduced for OLMoE?

Or am I loading it incorrectly?

Updated script I tried:

import argparse
import os
import torch
import yaml
import json
import gc
from pathlib import Path
from transformers import OlmoeForCausalLM, OlmoeConfig

def write_config_yaml(output_dir, config_dict):
    os.makedirs(output_dir, exist_ok=True)
    config_path = Path(output_dir) / "config.yaml"
    # Save the config under the "model" key as expected by OLMoE.
    with open(config_path, "w") as f:
        yaml.dump({"model": config_dict}, f, default_flow_style=False)

def convert_hf_to_olmoe(hf_dir, output_dir):
    print(f"Loading HF checkpoint from {hf_dir} …")
    model = OlmoeForCausalLM.from_pretrained(hf_dir, torch_dtype=torch.bfloat16)
    hf_state = model.state_dict()
    config: OlmoeConfig = model.config

    n_layers = config.num_hidden_layers
    n_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    dims_per_head = hidden_size // n_heads
    base = config.rope_theta if getattr(config, "rope_theta", None) is not None else 10000.0
    max_position_embeddings = config.max_position_embeddings
    vocab_size = config.vocab_size
    pad_token_id = config.pad_token_id
    eos_token_id = config.eos_token_id

    # Determine the number of key-value heads used (for multi-query or grouped-query attention).
    num_key_value_heads = config.num_key_value_heads if getattr(config, "num_key_value_heads", None) is not None else n_heads

    olmoe_state = {}

    # Process each transformer block/layer.
    for layer_i in range(n_layers):
        # --- Attention projection ---
        # HF model has separate q_proj, k_proj, v_proj weights.
        q_proj = hf_state[f"model.layers.{layer_i}.self_attn.q_proj.weight"]
        k_proj = hf_state[f"model.layers.{layer_i}.self_attn.k_proj.weight"]
        v_proj = hf_state[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
        # Concatenate to form the original combined weight.
        att_proj_weight = torch.cat([q_proj, k_proj, v_proj], dim=0)
        key_att_proj = f"transformer.blocks.{layer_i}.att_proj.weight"
        olmoe_state[key_att_proj] = att_proj_weight
        # Create a bias for att_proj (if original expects it) as zeros.
        key_att_proj_bias = f"transformer.blocks.{layer_i}.att_proj.bias"
        olmoe_state[key_att_proj_bias] = torch.zeros(att_proj_weight.size(0), dtype=att_proj_weight.dtype)

        # --- Attention output projection ---
        o_proj_weight = hf_state[f"model.layers.{layer_i}.self_attn.o_proj.weight"]
        key_attn_out = f"transformer.blocks.{layer_i}.attn_out.weight"
        olmoe_state[key_attn_out] = o_proj_weight
        key_attn_out_bias = f"transformer.blocks.{layer_i}.attn_out.bias"
        olmoe_state[key_attn_out_bias] = torch.zeros(o_proj_weight.size(0), dtype=o_proj_weight.dtype)

        # --- Layer normalization ---
        # In the forward conversion, "input_layernorm.weight" was saved as attn_norm.weight.
        attn_norm_weight = hf_state[f"model.layers.{layer_i}.input_layernorm.weight"]
        key_attn_norm = f"transformer.blocks.{layer_i}.attn_norm.weight"
        olmoe_state[key_attn_norm] = attn_norm_weight
        key_attn_norm_bias = f"transformer.blocks.{layer_i}.attn_norm.bias"
        olmoe_state[key_attn_norm_bias] = torch.zeros(attn_norm_weight.size(0), dtype=attn_norm_weight.dtype)

        # Similarly, "post_attention_layernorm.weight" becomes ff_norm.weight.
        ff_norm_weight = hf_state[f"model.layers.{layer_i}.post_attention_layernorm.weight"]
        key_ff_norm = f"transformer.blocks.{layer_i}.ff_norm.weight"
        olmoe_state[key_ff_norm] = ff_norm_weight
        key_ff_norm_bias = f"transformer.blocks.{layer_i}.ff_norm.bias"
        olmoe_state[key_ff_norm_bias] = torch.zeros(ff_norm_weight.size(0), dtype=ff_norm_weight.dtype)

        # --- Feed-forward (FFN) router ---
        router_weight = hf_state[f"model.layers.{layer_i}.mlp.gate.weight"]
        key_ff_router = f"transformer.blocks.{layer_i}.ffn.router.layer.weight"
        olmoe_state[key_ff_router] = router_weight
        key_ff_router_bias = f"transformer.blocks.{layer_i}.ffn.router.layer.bias"
        # If the original did not have a bias, fill with zeros.
        olmoe_state[key_ff_router_bias] = torch.zeros(router_weight.size(0), dtype=router_weight.dtype)

        # --- Experts ---
        # Determine number of experts by checking available expert keys.
        num_experts = 0
        while f"model.layers.{layer_i}.mlp.experts.{num_experts}.gate_proj.weight" in hf_state:
            num_experts += 1

        expert_w1_list = []
        expert_v1_list = []
        expert_w2_list = []
        for expert_i in range(num_experts):
            gate_proj = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_i}.gate_proj.weight"]
            expert_w1_list.append(gate_proj)
            up_proj = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_i}.up_proj.weight"]
            expert_v1_list.append(up_proj)
            # down_proj was transposed during the forward conversion; reverse that here.
            down_proj = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_i}.down_proj.weight"]
            expert_w2_list.append(down_proj.t().contiguous())
        # Concatenate experts’ weights along the first dimension.
        key_experts_w1 = f"transformer.blocks.{layer_i}.ffn.experts.mlp.w1"
        olmoe_state[key_experts_w1] = torch.cat(expert_w1_list, dim=0)
        key_experts_v1 = f"transformer.blocks.{layer_i}.ffn.experts.mlp.v1"
        olmoe_state[key_experts_v1] = torch.cat(expert_v1_list, dim=0)
        key_experts_w2 = f"transformer.blocks.{layer_i}.ffn.experts.mlp.w2"
        olmoe_state[key_experts_w2] = torch.cat(expert_w2_list, dim=0)

        # --- Rotary embeddings ---
        # Recompute inv_freq using dims_per_head and base.
        inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2, dtype=torch.float32) / dims_per_head))
        key_inv_freq = f"transformer.blocks.{layer_i}.self_attn.rotary_emb.inv_freq"
        olmoe_state[key_inv_freq] = inv_freq

        # --- (Optional) Feed-forward output projections ---
        # Some OLMoE implementations expect an additional FFN output projection.
        # If missing, we fill with zeros. Here we assume the expected shape is (hidden_size, hidden_size).
        key_ff_out_weight = f"transformer.blocks.{layer_i}.ff_out.weight"
        key_ff_out_bias = f"transformer.blocks.{layer_i}.ff_out.bias"
        olmoe_state[key_ff_out_weight] = torch.zeros(hidden_size, hidden_size, dtype=hf_state[f"model.layers.{layer_i}.mlp.gate.weight"].dtype)
        olmoe_state[key_ff_out_bias] = torch.zeros(hidden_size, dtype=hf_state[f"model.layers.{layer_i}.mlp.gate.weight"].dtype)

        # Similarly, if a feed-forward projection is expected.
        key_ff_proj_weight = f"transformer.blocks.{layer_i}.ff_proj.weight"
        key_ff_proj_bias = f"transformer.blocks.{layer_i}.ff_proj.bias"
        olmoe_state[key_ff_proj_weight] = torch.zeros(hidden_size, hidden_size, dtype=hf_state[f"model.layers.{layer_i}.mlp.gate.weight"].dtype)
        olmoe_state[key_ff_proj_bias] = torch.zeros(hidden_size, dtype=hf_state[f"model.layers.{layer_i}.mlp.gate.weight"].dtype)

    # --- Final (root) weights ---
    # Embedding tokens.
    olmoe_state["transformer.wte.weight"] = hf_state["model.embed_tokens.weight"]
    # Final LM head.
    olmoe_state["transformer.ff_out.weight"] = hf_state["lm_head.weight"]
    # Final layer norm.
    olmoe_state["transformer.ln_f.weight"] = hf_state["model.norm.weight"]
    olmoe_state["transformer.ln_f.bias"] = torch.zeros(hf_state["model.norm.weight"].size(0), dtype=hf_state["model.norm.weight"].dtype)

    # Save the OLMoE checkpoint as a single file "model.pt"
    os.makedirs(output_dir, exist_ok=True)
    output_model_path = os.path.join(output_dir, "model.pt")
    torch.save(olmoe_state, output_model_path)
    print(f"Saved OLMoE model weights to {output_model_path}")

    # --- Recreate the OLMoE config.yaml ---
    olmoe_config = {
        "n_layers": n_layers,
        "n_heads": n_heads,
        "d_model": hidden_size,
        "max_sequence_length": max_position_embeddings,
        "pad_token_id": pad_token_id,
        "eos_token_id": eos_token_id,
        "vocab_size": vocab_size,
        "weight_tying": config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") else True,
        "multi_query_attention": True if num_key_value_heads == 1 else False,
        "clip_qkv": getattr(config, "clip_qkv", None),
    }
    write_config_yaml(output_dir, olmoe_config)
    print(f"Saved OLMoE config to {os.path.join(output_dir, 'config.yaml')}")

    # Cleanup.
    del olmoe_state, hf_state
    gc.collect()

def main():
    parser = argparse.ArgumentParser(
        description="Convert an HF (Transformers) OLMoE checkpoint to the original OLMoE format."
    )
    parser.add_argument("--input_dir", required=True, help="Path to the HF checkpoint directory (contains config.json, pytorch_model.bin, etc.).")
    parser.add_argument("--output_dir", required=True, help="Path where the OLMoE checkpoint (model.pt and config.yaml) will be saved.")
    args = parser.parse_args()
    convert_hf_to_olmoe(args.input_dir, args.output_dir)

if __name__ == "__main__":
    main()

@Muennighoff (Collaborator)

the model class still remains OLMo

Yes! There are just some new keys related to the experts and gating, which you should have from the ckpt.
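
If it helps, a quick way to see exactly which keys differ is to diff the converted state dict against one saved by a short dummy OLMoE run (a sketch; both paths are placeholders and assume single-file, unsharded model.pt checkpoints):

# Sketch: diff key sets and shapes between the converted checkpoint and a reference
# checkpoint written by a dummy OLMoE training run. Paths are placeholders.
import torch

converted = torch.load("/path/to/converted/model.pt", map_location="cpu")
reference = torch.load("/path/to/dummy-run/model.pt", map_location="cpu")

print("missing from conversion:", sorted(set(reference) - set(converted)))
print("extra in conversion:", sorted(set(converted) - set(reference)))
for k in sorted(set(reference) & set(converted)):
    if reference[k].shape != converted[k].shape:
        print("shape mismatch:", k, tuple(converted[k].shape), "vs", tuple(reference[k].shape))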

@aditi184 (Author) commented Mar 7, 2025

Hey Niklas

After multiple attempts to convert the HF checkpoint to the OLMo class, I am still encountering issues like:

Missing key(s) in state_dict: "transformer.wpe.weight".
Unexpected key(s) in state_dict: "transformer.blocks.0.q_norm.weight", "transformer.blocks.0.k_norm.weight", "transformer.blocks.1.q_norm.weight", "transformer.blocks.1.k_norm.weight", "transformer.blocks.2.q_norm.weight", "transformer.blocks.2.k_norm.weight", "transformer.blocks.3.q_norm.weight", "transformer.blocks.3.k_norm.weight"

and shape mismatch errors.

Although I printed all the shapes and they look fine, the script still isn't working.

Can you provide the checkpoint in the OLMo format so that I can load it directly, like the checkpoints available on Hugging Face that can be loaded without the transformers library? Or, if possible, the conversion script.

# Run script using python convert_hf_to_olmo2.py --hf_model_path allenai/OLMoE-1B-7B-0924-Instruct --output_dir /home/mila/k/khandela/scratch/ai2-llm/checkpoints/OLMoE/test

import argparse
import json
import os
import torch
import yaml
from pathlib import Path
from transformers import OlmoeForCausalLM, OlmoeConfig
import ipdb
def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f, indent=4)

def save_olmo_checkpoint(model_path, hf_model_path):
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)
    
    # Load HF model
    hf_checkpoint = OlmoeForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True)
    hf_config = hf_checkpoint.config
    print(hf_config)
    n_layers = hf_config.num_hidden_layers
    n_heads = hf_config.num_attention_heads
    dim = hf_config.hidden_size
    dims_per_head = dim // n_heads
    num_key_value_heads = hf_config.num_key_value_heads
    num_experts = hf_config.intermediate_size // (dim // num_key_value_heads)  # rough estimate of the number of experts
    print(num_experts)
    base = hf_config.rope_theta

    state_dict = hf_checkpoint.state_dict()
    
    olmoe_state_dict = {}
    
    # Convert transformer layers
    for layer_i in range(hf_config.num_hidden_layers):
        q_proj_weight = state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"]
        k_proj_weight = state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"]
        v_proj_weight = state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
        
        # Fuse Q, K, V into a single tensor (OLMoE format)
        att_proj_weight = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)

        # Map back to OLMoE format
        olmoe_state_dict[f"transformer.blocks.{layer_i}.att_proj.weight"] = att_proj_weight
        olmoe_state_dict[f"transformer.blocks.{layer_i}.attn_out.weight"] = state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"]
        olmoe_state_dict[f"transformer.blocks.{layer_i}.q_norm.weight"] = state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"]
        olmoe_state_dict[f"transformer.blocks.{layer_i}.k_norm.weight"] = state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"]
        olmoe_state_dict[f"transformer.blocks.{layer_i}.ffn.router.layer.weight"] = state_dict[f"m
8000
odel.layers.{layer_i}.mlp.gate.weight"]
        olmoe_state_dict[f"transformer.blocks.{layer_i}.attn_norm.weight"] = state_dict[f"model.layers.{layer_i}.input_layernorm.weight"]
        olmoe_state_dict[f"transformer.blocks.{layer_i}.ff_norm.weight"] = state_dict[f"model.layers.{layer_i}.post_attention_layernorm.weight"]

        
        num_experts = state_dict[f"model.layers.{layer_i}.mlp.gate.weight"].shape[0]
        # dim_per_expert = state_dict[f"model.layers.{layer_i}.mlp.experts.0.gate_proj.weight"].shape[1]
        print(num_experts)
        # Reconstruct MoE expert weights
        w1_list = [state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.gate_proj.weight"] for expert_i in range(num_experts)]
        v1_list = [state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.up_proj.weight"] for expert_i in range(num_experts)]
        w2_list = [state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_i}.down_proj.weight"].T for expert_i in range(num_experts)]

        olmoe_state_dict[f"transformer.blocks.{layer_i}.ffn.experts.mlp.w1"] = torch.cat(w1_list, dim=0)
        olmoe_state_dict[f"transformer.blocks.{layer_i}.ffn.experts.mlp.v1"] = torch.cat(v1_list, dim=0)
        olmoe_state_dict[f"transformer.blocks.{layer_i}.ffn.experts.mlp.w2"] = torch.cat(w2_list, dim=0)

        


    # ipdb.set_trace()
    dims_per_head = dim // n_heads
    print("here")
    print(n_heads)
    print(dim)
    print(dims_per_head)
    base = 10000.0
    # olmoe_state_dict["transformer.inv_freq"] = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
    olmoe_state_dict["transformer.wte.weight"] = state_dict["model.embed_tokens.weight"]
    olmoe_state_dict["transformer.ff_out.weight"] = state_dict["lm_head.weight"]
    olmoe_state_dict["transformer.ln_f.weight"] = state_dict["model.norm.weight"]
    
    # Save model.pt
    torch.save(olmoe_state_dict, os.path.join(model_path, "model.pt"))
    
    # Save config.yaml
    olmoe_config = {
        "model": {
            "n_layers": hf_config.num_hidden_layers,
            "n_heads": hf_config.num_attention_heads,
            "d_model": hf_config.hidden_size,
            "embedding_size": hf_config.vocab_size,
            "max_sequence_length": hf_config.max_position_embeddings,
            "pad_token_id": hf_config.pad_token_id,
            "eos_token_id": hf_config.eos_token_id,
            "weight_tying": hf_config.tie_word_embeddings,
            "block_type": "moe",
            "n_kv_heads" : None,
            "clip_qkv" : None,
            "include_bias" : False,
            "bias_for_layer_norm" : False       
             }
    }
    print(olmoe_config)
    with open(os.path.join(model_path, "config.yaml"), "w") as f:
        yaml.dump(olmoe_config, f)
    
    print(f"Converted model saved to {model_path}")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_model_path", required=True, help="Path to the HF model directory.")
    parser.add_argument("--output_dir", required=True, help="Directory to save OLMo checkpoint.")
    args = parser.parse_args()
    save_olmo_checkpoint(args.output_dir, args.hf_model_path)

if __name__ == "__main__":
    main()

@Muennighoff (Collaborator)

Tagging @aman-17 here 🙌

@aman-17 (Member) commented May 12, 2025

@aditi184, did you solve this? Errors like the missing key transformer.wpe.weight and unexpected keys like q_norm.weight are expected if there is a structural mismatch between the HF checkpoint and what the OLMo class expects; it's likely due to architectural differences.
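
One way to surface all missing and unexpected keys in a single pass (shape mismatches will still raise) is a non-strict load. A rough sketch, assuming ModelConfig and OLMo can be built directly from the converted config.yaml's "model" section (not verified against the MoE branch; paths are placeholders):

# Sketch only: build an OLMo model from the converted config and load the converted
# weights non-strictly to collect every missing/unexpected key at once.
# Whether ModelConfig accepts these fields as-is on the MoE branch is an assumption.
import torch
import yaml
from olmo.config import ModelConfig
from olmo.model import OLMo

with open("/path/to/converted/config.yaml") as f:
    model_cfg = ModelConfig(**yaml.safe_load(f)["model"])

model = OLMo(model_cfg)
state = torch.load("/path/to/converted/model.pt", map_location="cpu")
result = model.load_state_dict(state, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)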

@aditi184 (Author)

Yes, I have written the conversion script for it.

I can create a PR if you would like to include it in the codebase.
