- 多种模型:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- 集成方法:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- 多种精度:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- 先进算法:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- 实用技巧:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- 实验监控:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- 极速推理:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
模型名 | 模型大小 | Template |
---|---|---|
Baichuan 2 | 7B/13B | baichuan2 |
BLOOM/BLOOMZ | 560M/1.1B/1.7B/3B/7.1B/176B | - |
ChatGLM3 | 6B | chatglm3 |
Command R | 35B/104B | cohere |
DeepSeek (Code/MoE) | 7B/16B/67B/236B | deepseek |
Falcon | 7B/11B/40B/180B | falcon |
Gemma/Gemma 2/CodeGemma | 2B/7B/9B/27B | gemma |
GLM-4 | 9B | glm4 |
InternLM2 | 7B/20B | intern2 |
Llama | 7B/13B/33B/65B | - |
Llama 2 | 7B/13B/70B | llama2 |
Llama 3 | 8B/70B | llama3 |
LLaVA-1.5 | 7B/13B | vicuna |
Mistral/Mixtral | 7B/8x7B/8x22B | mistral |
OLMo | 1B/7B | - |
PaliGemma | 3B | gemma |
Phi-1.5/Phi-2 | 1.3B/2.7B | - |
Phi-3 | 4B/7B/14B | phi |
Qwen/Qwen1.5/Qwen2 (Code/MoE) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
StarCoder 2 | 3B/7B/15B | - |
XVERSE | 7B/13B/65B | xverse |
Yi/Yi-1.5 | 6B/9B/34B | yi |
Yi-VL | 6B/34B | yi_vl |
Yuan 2 | 2B/51B/102B | yuan |
Note
对于所有“基座”(Base)模型,template
参数可以是 default
, alpaca
, vicuna
等任意值。但“对话”(Instruct/Chat)模型请务必使用对应的模板。
请务必在训练和推理时采用完全一致的模板。
项目所支持模型的完整列表请参阅 constants.py。
您也可以在 template.py 中添加自己的对话模板。
预训练数据集
指令微调数据集
- Identity (en&zh)
- Stanford Alpaca (en)
- Stanford Alpaca (zh)
- Alpaca GPT4 (en&zh)
- Glaive Function Calling V2 (en&zh)
- LIMA (en)
- Guanaco Dataset (multilingual)
- BELLE 2M (zh)
- BELLE 1M (zh)
- BELLE 0.5M (zh)
- BELLE Dialogue 0.4M (zh)
- BELLE School Math 0.25M (zh)
- BELLE Multiturn Chat 0.8M (zh)
- UltraChat (en)
- OpenPlatypus (en)
- CodeAlpaca 20k (en)
- Alpaca CoT (multilingual)
- OpenOrca (en)
- SlimOrca (en)
- MathInstruct (en)
- Firefly 1.1M (zh)
- Wiki QA (en)
- Web QA (zh)
- WebNovel (zh)
- Nectar (en)
- deepctrl (en&zh)
- Advertise Generating (zh)
- ShareGPT Hyperfiltered (en)
- ShareGPT4 (en&zh)
- UltraChat 200k (en)
- AgentInstruct (en)
- LMSYS Chat 1M (en)
- Evol Instruct V2 (en)
- Cosmopedia (en)
- STEM (zh)
- Ruozhiba (zh)
- Neo-sft (zh)
- WebInstructSub (en)
- Magpie-Pro-300K-Filtered (en)
- LLaVA mixed (en&zh)
- Open Assistant (de)
- Dolly 15k (de)
- Alpaca GPT4 (de)
- OpenSchnabeltier (de)
- Evol Instruct (de)
- Dolphin (de)
- Booksum (de)
- Airoboros (de)
- Ultrachat (de)
偏好数据集
必需项 | 至少 | 推荐 |
---|---|---|
python | 3.8 | 3.11 |
torch | 1.13.1 | 2.3.0 |
transformers | 4.41.2 | 4.41.2 |
datasets | 2.16.0 | 2.19.2 |
accelerate | 0.30.1 | 0.30.1 |
peft | 0.11.1 | 0.11.1 |
trl | 0.8.6 | 0.9.4 |
可选项 | 至少 | 推荐 |
---|---|---|
CUDA | 11.6 | 12.2 |
deepspeed | 0.10.0 | 0.14.0 |
bitsandbytes | 0.39.0 | 0.43.1 |
vllm | 0.4.3 | 0.4.3 |
flash-attn | 2.3.0 | 2.5.9 |
* 估算值
方法 | 精度 | 7B | 13B | 30B | 70B | 110B | 8x7B | 8x22B |
---|---|---|---|---|---|---|---|---|
Full | AMP | 120GB | 240GB | 600GB | 1200GB | 2000GB | 900GB | 2400GB |
Full | 16 | 60GB | 120GB | 300GB | 600GB | 900GB | 400GB | 1200GB |
Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 360GB | 160GB | 400GB |
LoRA/GaLore/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | 240GB | 120GB | 320GB |
QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 140GB | 60GB | 160GB |
QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 72GB | 30GB | 96GB |
QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 48GB | 18GB | 48GB |
- 指令监督微调数据集,数据需要放在data目录下,以Alpaca 格式
- 样例数据集在指令监督微调时,
instruction
列对应的内容会与input
列对应的内容拼接后作为人类指令,即人类指令为instruction\ninput
。而output
列对应的内容为模型回答。
如果指定,system
列对应的内容将被作为系统提示词。
history
列是由多个字符串二元组构成的列表,分别代表历史消息中每轮对话的指令和回答。注意在指令监督微调时,历史消息中的回答内容也会被用于模型学习。
// 数据格式样例
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [
["第一轮指令(选填)", "第一轮回答(选填)"],
["第二轮指令(选填)", "第二轮回答(选填)"]
]
}
]
- 在dataset_info.json中将构造数据字段进行描述,
对于上述格式的数据,
dataset_info.json
中的数据集描述应为:
"数据集名称": {
"file_name": "data.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"system": "system",
"history": "history"
}
}
在examples目录下创建训练配置文件,例如sft配置文件
- 指定验证集:在配置文件中指定val_dataset参数
### dataset
dataset: wedoctor-0711
val_dataset: test-eval
template: gemma
# max_padding_len
cutoff_len: 1024
# max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
- 指定验证方式:在metrics中自定义eval_accuracy的计算方式
# 自定义计算逻辑
def compute_metrics_for_label(pred_str, label_str):
match = re.search(r'\b([ABCDE])\b', pred_str)
if match:
predicted_answer = match.group(1)
return predicted_answer == label_str
return False
# 指定计算函数
COMPUTE_METHOD = {"func": compute_metrics_for_label}
使用shell脚本开启训练
sh run.sh
gemma训练:使用run_gemma.sh脚本
sh run_gemma.sh