fix: Handle missing and unexpected keys during LLMEncoder state dict load by jeffkinnison · Pull Request #3841 · ludwig-ai/ludwig
Merged 3 commits on Dec 20, 2023

42 changes: 37 additions & 5 deletions ludwig/encoders/text_encoders.py
@@ -2391,6 +2391,8 @@ class LLMEncoder(Encoder):

     def __init__(self, encoder_config: LLMEncoderConfig = None, **kwargs):
         super().__init__()
+        self.register_load_state_dict_post_hook(self.remove_missing_non_adapter_keys)
+
         self.config = encoder_config

         self.model_name = self.config.base_model
@@ -2511,6 +2513,7 @@ def _load_from_state_dict(
         # Call this first to make sure torch can do its usual load. In the adapter case, this should essentially be a
         # no-op, but the adapter weights will be collected in `unexpected_keys` because PEFT changes the parameter
         # names under the hood.
+
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
@@ -2525,8 +2528,37 @@ def _load_from_state_dict(
             peft_model_state_dict = {k: v for k, v in state_dict.items() if adapter_type_prefix in k}
             set_peft_model_state_dict(self.model, peft_model_state_dict)

-            if strict:
-                for k in peft_model_state_dict.keys():
-                    sanitized = k.replace(f"{prefix}model.", "")  # `unexpected_keys` doesn't record the prefix
-                    sanitized = sanitized.replace("default.", "")  # By default, PEFT adds a "default." to param names
-                    unexpected_keys.remove(sanitized)
+    def remove_missing_non_adapter_keys(self, module, incompatible_keys):
+        """Update the missing and unexpected keys lists to reflect custom adapter state load logic.
+
+        This method should never return anything unless the underlying torch hook logic is updated. Any changes to the
+        lists in `incompatible_keys` must be made in-place.
+
+        Args:
+            module: The torch module with newly loaded state
+            incompatible_keys: A tuple with the lists of missing and unexpected keys that were recorded while loading
+        """
+        # If no adapter was used, `LLMEncoder.load_state_dict` should use the default `torch.Module.load_state_dict`
+        # code path to load weights and no modification should be necessary.
+        if self.config.adapter:
+            adapter_type_prefix = self.ADAPTER_PARAM_NAME_PREFIX[self.config.adapter.type]
+            missing_keys, unexpected_keys = incompatible_keys
+
+            # When loading the adapter weights in strict mode, torch will register the base model weights as missing
+            # from the state dict and raise an exception. The base model weights are intended to be excluded, so the
+            # missing_keys list is updated post-load to avoid the error.
+            for k, _ in self.named_parameters():
+                if k in missing_keys and adapter_type_prefix not in k:
+                    missing_keys.remove(k)
+
+            # PEFT changes the adapter parameter names under the hood to include the adapter name. When retrieving the
+            # adapter state dict, however, the name is not included. This causes the adapter weights to be recorded as
+            # unexpected parameters. `LLMEncoder._load_from_state_dict` loads the adapter parameters using a PEFT
+            # utility that accounts for the updated names, so here we remove any adapter parameters from the unexpected
+            # keys list to avoid errors.
+            from peft.utils.save_and_load import get_peft_model_state_dict
+
+            sd = get_peft_model_state_dict(self.model)
+            for k in sd.keys():
+                if k in unexpected_keys:
+                    unexpected_keys.remove(k)
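
The hook registered in `__init__` relies on torch's `register_load_state_dict_post_hook` contract: the hook receives the module and an `incompatible_keys` named tuple of `(missing_keys, unexpected_keys)`, and any keys pruned from those lists in place are excluded from the strict-mode check. A minimal standalone sketch of that mechanism, using a hypothetical `TinyEncoder` in place of the Ludwig encoder:

# Minimal sketch of the post-hook mechanism, independent of Ludwig. `TinyEncoder`
# is hypothetical; `base` stands in for the frozen base-model weights and
# `adapter` for the adapter weights.
import torch.nn as nn


class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4)
        self.adapter = nn.Linear(4, 4)
        self.register_load_state_dict_post_hook(self._ignore_missing_base_keys)

    def _ignore_missing_base_keys(self, module, incompatible_keys):
        # Mutate the lists in place; the hook must not return a value.
        missing_keys, _unexpected_keys = incompatible_keys
        for key in list(missing_keys):
            if key.startswith("base."):
                missing_keys.remove(key)


encoder = TinyEncoder()
adapter_only_sd = {k: v for k, v in encoder.state_dict().items() if k.startswith("adapter.")}

# Without the hook, this strict load would raise because the "base.*" keys are missing.
encoder.load_state_dict(adapter_only_sd, strict=True)

In `LLMEncoder`, the same in-place pruning is what lets a checkpoint that stores only adapter weights pass torch's strict check.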
48 changes: 24 additions & 24 deletions tests/ludwig/encoders/test_llm_encoders.py
@@ -103,35 +103,35 @@ def weights_init(m):
         encoder2_sd = encoder2.state_dict()
         assert all(map(lambda k: torch.equal(encoder1_sd[k], encoder2_sd[k]), encoder1_sd.keys()))

-    # def test_load_from_state_dict_adapter(self, encoder_config_with_adapter: LLMEncoderConfig):
-    #     def weights_init(m):
-    #         """Reinitialize the weights of a torch module."""
-    #         if hasattr(m, "weight") and m.weight.ndim > 1:
-    #             torch.nn.init.xavier_uniform_(m.weight.data)
+    def test_load_from_state_dict_adapter(self, encoder_config_with_adapter: LLMEncoderConfig):
+        def weights_init(m):
+            """Reinitialize the weights of a torch module."""
+            if hasattr(m, "weight") and m.weight.ndim > 1:
+                torch.nn.init.xavier_uniform_(m.weight.data)

-    #     # Create two encoders from the same config
-    #     encoder1 = LLMEncoder(encoder_config=encoder_config_with_adapter)
-    #     encoder2 = LLMEncoder(encoder_config=encoder_config_with_adapter)
+        # Create two encoders from the same config
+        encoder1 = LLMEncoder(encoder_config=encoder_config_with_adapter)
+        encoder2 = LLMEncoder(encoder_config=encoder_config_with_adapter)

-    #     encoder2.apply(weights_init)
+        encoder2.apply(weights_init)

-    #     encoder1_sd = encoder1.state_dict()
-    #     encoder2_sd = encoder2.state_dict()
-    #     adapter_keys = [k for k in encoder1_sd.keys() if "lora_" in k and "weight" in k]
-    #     model_keys = [k for k in encoder1_sd.keys() if "lora_" not in k]
+        encoder1_sd = encoder1.state_dict()
+        encoder2_sd = encoder2.state_dict()
+        adapter_keys = [k for k in encoder1_sd.keys() if "lora_" in k and "weight" in k]
+        model_keys = [k for k in encoder1_sd.keys() if "lora_" not in k]

-    #     # The LoRA weights should no longer be equal
-    #     assert all(map(lambda k: not torch.equal(encoder1_sd[k], encoder2_sd[k]), adapter_keys))
+        # The LoRA weights should no longer be equal
+        assert all(map(lambda k: not torch.equal(encoder1_sd[k], encoder2_sd[k]), adapter_keys))

-    #     # The remaining weights should also no longer be equal
-    #     assert all(map(lambda k: not torch.equal(encoder1_sd[k], encoder2_sd[k]), model_keys))
+        # The remaining weights should also no longer be equal
+        assert all(map(lambda k: not torch.equal(encoder1_sd[k], encoder2_sd[k]), model_keys))

-    #     # Load the weights of encoder1 back into encoder2
-    #     encoder2.load_state_dict(encoder1_sd)
-    #     encoder2_sd = encoder2.state_dict()
+        # Load the weights of encoder1 back into encoder2
+        encoder2.load_state_dict(encoder1_sd)
+        encoder2_sd = encoder2.state_dict()

-    #     # The LoRA weights should now be equal again
-    #     assert all(map(lambda k: torch.equal(encoder1_sd[k], encoder2_sd[k]), adapter_keys))
+        # The LoRA weights should now be equal again
+        assert all(map(lambda k: torch.equal(encoder1_sd[k], encoder2_sd[k]), adapter_keys))

-    #     # The remaining weights should still be unequal
-    #     assert all(map(lambda k: not torch.equal(encoder1_sd[k], encoder2_sd[k]), model_keys))
+        # The remaining weights should still be unequal
+        assert all(map(lambda k: not torch.equal(encoder1_sd[k], encoder2_sd[k]), model_keys))
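
The unexpected-keys cleanup in `remove_missing_non_adapter_keys` compensates for the PEFT naming quirk the encoder comments describe: registered parameter names include the adapter name ("default"), while the exported adapter state dict does not. A small sketch that makes the mismatch visible, assuming `torch` and `peft` are installed and using a hypothetical `TinyModel` that is not part of this PR or its tests:

# Hypothetical toy model; the LoRA config values are illustrative only.
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from peft.utils.save_and_load import get_peft_model_state_dict


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


peft_model = get_peft_model(TinyModel(), LoraConfig(r=4, target_modules=["proj"]))

# Registered parameter names carry the adapter name, e.g.
# "base_model.model.proj.lora_A.default.weight".
print([k for k, _ in peft_model.named_parameters() if "lora_" in k])

# The exported adapter state dict drops "default.", e.g.
# "base_model.model.proj.lora_A.weight", so a strict torch load records the
# loaded adapter keys as "unexpected" unless the post-hook removes them.
print(list(get_peft_model_state_dict(peft_model).keys()))

This is the same mismatch the removed `sanitized = ...` lines tried to patch by string manipulation; the new hook instead asks PEFT for the adapter key set and prunes against it.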