From 8d1ad505b316bd8616acb23224e03680eacc9bcf Mon Sep 17 00:00:00 2001
From: fh
Date: Wed, 12 Jun 2024 22:01:03 +0200
Subject: [PATCH] patch lora bin to dump json config

---
 eole/bin/model/lora_weights.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/eole/bin/model/lora_weights.py b/eole/bin/model/lora_weights.py
index 803085866..174d320b9 100644
--- a/eole/bin/model/lora_weights.py
+++ b/eole/bin/model/lora_weights.py
@@ -3,6 +3,7 @@
 from eole.models.model_saver import load_checkpoint
 from eole.models import get_model_class
 from eole.inputters.inputter import dict_to_vocabs, vocabs_to_dict
+from eole.config import recursive_model_fields_set
 from safetensors import safe_open
 from safetensors.torch import save_file
 import glob
@@ -101,8 +102,9 @@ def run(cls, args):
         with open(os.path.join(args.output, "vocab.json"), "w", encoding="utf-8") as f:
             json.dump(vocab_dict, f, indent=2, ensure_ascii=False)
         # save config
+        config_dict = recursive_model_fields_set(new_config)
         with open(os.path.join(args.output, "config.json"), "w", encoding="utf-8") as f:
-            json.dump(new_config, f, indent=2, ensure_ascii=False)
+            json.dump(config_dict, f, indent=2, ensure_ascii=False)
         shards = glob.glob(os.path.join(args.base_model, "model.*.safetensors"))
         f = []
         for i, shard in enumerate(shards):
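
Note (not part of the patch): a minimal sketch of why the conversion step is needed,
assuming new_config is a pydantic-style config model as used elsewhere in eole, and
assuming recursive_model_fields_set returns a plain dict of the explicitly-set fields,
recursing into nested models. The class names, fields, and the re-implementation of
recursive_model_fields_set below are hypothetical stand-ins for illustration, not
eole's actual code; the sketch requires pydantic v2.

import json
from pydantic import BaseModel


class TrainingConfig(BaseModel):  # hypothetical stand-in for an eole sub-config
    lora_rank: int = 8
    lora_alpha: float = 16.0


class RunConfig(BaseModel):  # hypothetical stand-in for the merged model config
    model_path: str = "model"
    training: TrainingConfig = TrainingConfig()


def recursive_model_fields_set(model):
    # Sketch of the assumed behavior: keep only fields the user explicitly
    # set (pydantic tracks these in model_fields_set) and recurse into
    # nested models so the result is a plain, JSON-serializable dict.
    out = {}
    for name in model.model_fields_set:
        value = getattr(model, name)
        out[name] = (
            recursive_model_fields_set(value)
            if isinstance(value, BaseModel)
            else value
        )
    return out


new_config = RunConfig(model_path="lora-merged", training=TrainingConfig(lora_rank=16))

# json.dump(new_config, ...) raises "TypeError: Object of type RunConfig is
# not JSON serializable", which is why the patch converts to a dict first:
config_dict = recursive_model_fields_set(new_config)
print(json.dumps(config_dict, indent=2))
# -> {"model_path": "lora-merged", "training": {"lora_rank": 16}}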