[LLM] Add MTP for Deepseekv3 by DrownFish19 · Pull Request #9876 · PaddlePaddle/PaddleNLP · GitHub

[LLM] Add MTP for Deepseekv3 #9876


Merged — 10 commits merged into PaddlePaddle:develop on Feb 25, 2025

Conversation

DrownFish19
Collaborator

Before submitting

  • Lint the code. If there are lint issues, please format the code first:

    # Install and register `pre-commit` in the project folder
    pip install pre-commit && pre-commit install

    # Process previously changed code files separately
    pre-commit run --file XXXX.py

  • Add test cases into the tests folder. If there are codecov issues, please add test cases first.

PR types

New features

PR changes

Models

Description

  1. Add MTP (multi-token prediction) for Deepseekv3.

paddle-bot bot commented Feb 17, 2025

Thanks for your contribution!

codecov bot commented Feb 17, 2025

Codecov Report

Attention: Patch coverage is 12.71186% with 206 lines in your changes missing coverage. Please review.

Project coverage is 51.28%. Comparing base (30df8b6) to head (f1676df).
Report is 338 commits behind head on develop.

Files with missing lines                                 Patch %    Lines
paddlenlp/transformers/deepseek_v2/modeling.py           3.57%      108 Missing ⚠️
paddlenlp/transformers/deepseek_v2/modeling_pp.py        7.69%      84 Missing ⚠️
paddlenlp/transformers/moe_gate.py                       42.85%     8 Missing ⚠️
paddlenlp/transformers/deepseek_v3/modeling.py           0.00%      3 Missing ⚠️
paddlenlp/transformers/moe_layer.py                      71.42%     2 Missing ⚠️
...addlenlp/transformers/deepseek_v2/configuration.py    0.00%      1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9876      +/-   ##
===========================================
- Coverage    51.34%   51.28%   -0.07%     
===========================================
  Files          745      745              
  Lines       118590   118778     +188     
===========================================
+ Hits         60886    60910      +24     
- Misses       57704    57868     +164     

@@ -682,13 +687,14 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Constant(0.0),
        )
        self.e_score_correction_bias.is_distributed = True
Collaborator

Does this have a gradient?

Collaborator Author

Yes, it has a gradient and needs to be updated during training.
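
For context, here is a minimal sketch of how such a trainable correction bias can be declared in a Paddle layer; the wrapper class and shape below are illustrative, not the PR's exact code:

import paddle
import paddle.nn as nn

class MoEGateWithBias(nn.Layer):  # hypothetical wrapper for illustration
    def __init__(self, num_experts):
        super().__init__()
        # Trainable per-expert correction bias; it receives gradients and is
        # updated by the optimizer, per the discussion above.
        self.e_score_correction_bias = self.create_parameter(
            shape=[num_experts],
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Constant(0.0),
        )
        # Flag consumed by Paddle's hybrid-parallel machinery, as in the diff.
        self.e_score_correction_bias.is_distributed = True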

        k_pe = GatherOp.apply(k_pe)
        k_pe = k_pe.reshape([-1, q_len, 1, self.qk_rope_head_dim]).expand(
            [-1, q_len, self.num_heads, self.qk_rope_head_dim]
        )
Collaborator

Is this the fix for sequence parallel (SP)?

Collaborator Author

SP isn't fixed yet; part of the code is kept here for now.
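
For reference, a toy sketch of the reshape/expand above using plain tensors; the shapes are illustrative and the distributed GatherOp is omitted:

import paddle

batch, q_len, num_heads, qk_rope_head_dim = 2, 8, 4, 16
# After the gather, k_pe is flat over (batch * q_len); it carries one shared
# RoPE key per position, which is then broadcast across all attention heads.
k_pe = paddle.randn([batch * q_len, qk_rope_head_dim])
k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]).expand(
    [-1, q_len, num_heads, qk_rope_head_dim]
)
print(k_pe.shape)  # [2, 8, 4, 16]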

-        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+        query_states = paddle.concat([q_nope, q_pe], axis=-1)
+        key_states = paddle.concat([k_nope, k_pe], axis=-1)
Collaborator

Did the implementation change here?

Collaborator Author

This follows the auto-parallel implementation; the results are identical.
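
A quick self-contained check (toy shapes, assumed names) that the concat matches the old slice-assignment result:

import paddle

batch, q_len, heads, nope_dim, rope_dim = 1, 4, 2, 8, 4
k_nope = paddle.randn([batch, q_len, heads, nope_dim])
k_pe = paddle.randn([batch, q_len, heads, rope_dim])

# Old implementation: preallocate, then fill the two halves by slicing.
key_a = paddle.zeros([batch, q_len, heads, nope_dim + rope_dim])
key_a[:, :, :, :nope_dim] = k_nope
key_a[:, :, :, nope_dim:] = k_pe

# New implementation: a single concat along the last axis.
key_b = paddle.concat([k_nope, k_pe], axis=-1)

assert bool(paddle.allclose(key_a, key_b))  # identical results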

        if self.sequence_parallel:
            inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]])
            inputs_embeds = ScatterOp.apply(inputs_embeds)
        return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
Collaborator

There may be a latent GPU-memory issue here. Right now input_emb and mtp_emb travel the whole pipeline-parallel (PP) path and are sent to later stages together, right?

Collaborator Author

There does indeed seem to be a memory issue: every layer carries mtp_emb. Let me think this part over.

Collaborator Author

There is no better way to compute this for now. Later we could try sending the complete input_embed onward directly, which should only cost one extra hidden_state.
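
A rough back-of-the-envelope for that single extra hidden_state per pipeline send, with hypothetical numbers (batch 1, sequence length 4096, DeepSeek-V3's 7168 hidden size, bf16):

# Hypothetical sizes, only to gauge the order of magnitude.
batch, seq_len, hidden, bytes_per_elem = 1, 4096, 7168, 2
extra_mib = batch * seq_len * hidden * bytes_per_elem / 2**20
print(f"~{extra_mib:.0f} MiB per extra hidden_state")  # ~56 MiB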



class DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer):
    def forward(self, args):
        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)

        if self.config.num_nextn_predict_layers > 0:
            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1)
            inputs_embeds_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
Collaborator

This is quite inefficient. If the embedding weight is shared with the last layer, the lookup should happen at the last layer instead; check whether the last layer can fetch the embeddings by label index.

Collaborator Author

done
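
For clarity, the split in the Pipe forward above cuts the packed tensor into (num_nextn_predict_layers + 1) equal chunks along axis 0 (paddle.split's default axis); a toy sketch with assumed shapes:

import paddle

num_nextn = 1  # stands in for config.num_nextn_predict_layers
seq_len, hidden = 8, 16
# Main-path hidden states plus MTP embeddings arrive packed along axis 0.
packed = paddle.randn([(num_nextn + 1) * seq_len, hidden])
pieces = paddle.split(packed, num_nextn + 1)  # axis=0 by default
hidden_states = pieces[0]                # [8, 16]
inputs_embeds_mtp = pieces[-num_nextn:]  # list with the MTP chunk(s)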


        inputs_embeds_cur_depth = paddle.concat(
            [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
        )
Collaborator

It would be better to re-do the embedding lookup here; the memory consumption would be smaller.

Collaborator Author

It's a compute-versus-memory trade-off. The current approach uses extra memory of [batch_size, n, hidden_size], where n is the number of MTP layers; that overhead is still acceptable.

        hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]])

        inputs_embeds_cur_depth = paddle.concat(
            [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
        )
Collaborator

What is this concat joining?

Collaborator Author

The input is [1, 2, 3, 4, 5], and the embeddings cover [1, 2, 3, 4, 5]. Tokens [1, 2, 3, 4] go through the forward pass exactly as in the standard decoder architecture, while the MTP layer processes [2, 3, 4, 5]. The concat here stitches [2, 3, 4] together with [5].
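
That token shift is easy to verify with a toy tensor, where scalar "embeddings" stand in for real ones (names as in the snippet above):

import paddle

nextn = 0  # MTP depth 0
# Main input covers tokens [1, 2, 3, 4]; the extra slot carries token [5].
inputs_embeds_ori = paddle.to_tensor([[[1.0], [2.0], [3.0], [4.0]]])  # [1, 4, 1]
inputs_embeds_extra = paddle.to_tensor([[[5.0]]])                     # [1, 1, 1]
inputs_embeds_cur_depth = paddle.concat(
    [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
)
print(inputs_embeds_cur_depth.squeeze().tolist())  # [2.0, 3.0, 4.0, 5.0]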

-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return tuple(
+            v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, mtp_outputs] if v is not None
+        )
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
Collaborator

Should this also return mtp_outputs?

Collaborator Author

done
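
One plausible way to carry mtp_outputs in the structured return as well; the subclass name and field below are hypothetical sketches, not necessarily the merged code:

from dataclasses import dataclass
from typing import Optional

from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast

@dataclass
class BaseModelOutputWithPastAndMTP(BaseModelOutputWithPast):  # hypothetical
    # Hidden states produced by the MTP layers, one entry per predict depth.
    mtp_outputs: Optional[list] = None

# ...then, in the model's forward:
# return BaseModelOutputWithPastAndMTP(
#     last_hidden_state=hidden_states,
#     past_key_values=next_cache,
#     hidden_states=all_hidden_states,
#     attentions=all_self_attns,
#     mtp_outputs=mtp_outputs,
# )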

@ZHUI ZHUI merged commit 7acaf18 into PaddlePaddle:develop Feb 25, 2025
8 of 12 checks passed
@DrownFish19 DrownFish19 deleted the dev_20250214_deepseek_mtp branch February 25, 2025 05:36