[LLM] Add MTP for Deepseekv3 #9876
Conversation
Thanks for your contribution!
Codecov Report
Attention: Patch coverage is

Additional details and impacted files:

@@            Coverage Diff             @@
##           develop    #9876      +/-   ##
===========================================
- Coverage    51.34%   51.28%    -0.07%
===========================================
  Files          745      745
  Lines       118590   118778     +188
===========================================
+ Hits         60886    60910      +24
- Misses       57704    57868     +164

View full report in Codecov by Sentry.
@@ -682,13 +687,14 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Constant(0.0),
        )
        self.e_score_correction_bias.is_distributed = True
Reviewer: Does this parameter have a gradient?
Author: Yes, it has a gradient and needs to be updated.
k_pe = GatherOp.apply(k_pe)
k_pe = k_pe.reshape([-1, q_len, 1, self.qk_rope_head_dim]).expand(
    [-1, q_len, self.num_heads, self.qk_rope_head_dim]
)
Reviewer: Is this the fix for sequence parallel (sp)?
Author: sp is not fixed yet; part of that code is kept here.
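The reshape/expand in the snippet above broadcasts the single shared rotary key across all attention heads. A minimal NumPy sketch of that broadcast, using toy sizes rather than the model's real dimensions:

```python
import numpy as np

# Toy sizes standing in for the model's dimensions (assumptions, not the PR's config).
q_len, num_heads, qk_rope_head_dim = 4, 3, 2

# k_pe holds one rotary key per position, shared by every head: [batch, q_len, 1, dim].
k_pe = np.random.default_rng(0).standard_normal((1, q_len, 1, qk_rope_head_dim))

# Mirror reshape([-1, q_len, 1, dim]).expand([-1, q_len, num_heads, dim]):
# the same per-position rotary key is replicated across the head axis.
k_pe_expanded = np.broadcast_to(k_pe, (1, q_len, num_heads, qk_rope_head_dim))
```

Every head then sees an identical rotary key for a given position, which is what MLA's decoupled RoPE key requires.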
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
query_states = paddle.concat([q_nope, q_pe], axis=-1)
key_states = paddle.concat([k_nope, k_pe], axis=-1)
Reviewer: Is this a changed implementation?
Author: It follows the auto-parallel implementation; the results are identical.
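The diff replaces in-place slice assignment with a single concat along the last axis. A small NumPy sketch (toy shapes, hypothetical sizes) showing the two forms produce identical key_states:

```python
import numpy as np

# Hypothetical small dims standing in for the model's head sizes.
batch, q_len, num_heads = 2, 4, 3
qk_nope_head_dim, qk_rope_head_dim = 5, 2

rng = np.random.default_rng(0)
k_nope = rng.standard_normal((batch, q_len, num_heads, qk_nope_head_dim))
k_pe = rng.standard_normal((batch, q_len, num_heads, qk_rope_head_dim))

# Old style: allocate a buffer, then fill the nope/pe slices in place.
key_states_old = np.empty((batch, q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim))
key_states_old[:, :, :, :qk_nope_head_dim] = k_nope
key_states_old[:, :, :, qk_nope_head_dim:] = k_pe

# New style: a single concat along the last axis.
key_states_new = np.concatenate([k_nope, k_pe], axis=-1)
```

The concat form avoids the intermediate mutable buffer, which tends to be friendlier to graph capture and auto-parallel partitioning.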
if self.sequence_parallel:
    inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]])
    inputs_embeds = ScatterOp.apply(inputs_embeds)
return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
Reviewer: There may be a latent memory issue here. Are input_emb and mtp_emb now sent backward together along a single pipeline-parallel path?
Author: There does seem to be a memory issue; every layer carries an mtp_emb. Let me think this part over.
Author: There is no better way to compute this for now. Later we could try passing the complete input_embed backward directly, which should only cost one extra hidden_state.
class DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer):
    def forward(self, args):
        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)

        if self.config.num_nextn_predict_layers > 0:
            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1)
            inputs_embeds_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
Reviewer: This is quite inefficient. If the last layer shares weights, it should fetch from the last layer instead; and check whether the last layer can gather by label idx.
Author: done
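For readers following the pipeline path: hidden_states arrives at the stage with the main sequence and the per-depth MTP embeddings stacked in one tensor, and the split above unpacks them. A NumPy sketch under assumed shapes (splitting along axis 0, matching paddle.split's default axis):

```python
import numpy as np

# Assumed toy shapes; the real tensor carries [seq, hidden]-like activations.
num_nextn_predict_layers = 1
seq_len, hidden = 4, 2

# (1 + n) sequences stacked along the leading axis before being sent between stages.
hidden_states = np.arange(
    (1 + num_nextn_predict_layers) * seq_len * hidden, dtype=float
).reshape(-1, hidden)

# Unpack: first chunk is the main decoder path, the rest are MTP-depth embeddings.
parts = np.split(hidden_states, 1 + num_nextn_predict_layers, axis=0)
main_hidden = parts[0]
inputs_embeds_mtp = parts[-num_nextn_predict_layers:]
```

Packing everything into one tensor keeps the pipeline send/recv interface unchanged, at the cost of the extra activation traffic discussed above.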
inputs_embeds_cur_depth = paddle.concat(
    [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
)
Reviewer: It would be better to re-fetch the embeddings here to reduce memory consumption.
Author: It is a trade-off between compute and memory. The current approach uses an extra [batch_size, n, hidden_size] of memory, where n is the number of MTP layers, which is still acceptable.
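A back-of-envelope check of the author's [batch_size, n, hidden_size] figure, with illustrative sizes (assumptions, not the PR's actual configuration):

```python
# Hypothetical sizes: batch of 8, one MTP layer, a DeepSeek-V3-like hidden size, bf16.
batch_size, n_mtp, hidden_size = 8, 1, 7168
bytes_per_elem = 2  # bf16

# One extra [batch_size, n, hidden_size] buffer per micro-batch.
extra_bytes = batch_size * n_mtp * hidden_size * bytes_per_elem
print(f"extra ~= {extra_bytes / 1024:.0f} KiB")
```

At these sizes the overhead is on the order of a hundred KiB, small relative to full activation memory, which supports the "still acceptable" judgment.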
hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]])

inputs_embeds_cur_depth = paddle.concat(
    [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
Reviewer: What is being concatenated here?
Author: If the input is [1, 2, 3, 4, 5], the embedded sequence is [1, 2, 3, 4, 5]. [1, 2, 3, 4] goes through the standard decoder forward, while the MTP layer processes [2, 3, 4, 5]; this concat stitches [2, 3, 4] together with [5].
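That example can be sketched directly. With a toy token stream [1, 2, 3, 4, 5], one MTP depth, and token ids standing in for embedding rows, the concat rebuilds the one-token-shifted sequence the MTP layer trains on:

```python
import numpy as np

# Main path sees [1, 2, 3, 4]; the trailing token [5] is kept separately.
inputs_embeds_ori = np.array([[1, 2, 3, 4]])[:, :, None]   # [batch, seq, hidden=1]
inputs_embeds_extra = np.array([[5]])[:, :, None]          # trailing token(s) for MTP

nextn = 0  # first (and only) MTP depth
inputs_embeds_cur_depth = np.concatenate(
    [inputs_embeds_ori[:, (nextn + 1):, :], inputs_embeds_extra[:, :(nextn + 1), :]],
    axis=1,
)
# The result is the shifted sequence [2, 3, 4, 5].
```

Deeper MTP depths (larger nextn) would shift the window one more token each time, following the same slicing pattern.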
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(
    v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, mtp_outputs] if v is not None
)
return BaseModelOutputWithPast(
    last_hidden_state=hidden_states,
    past_key_values=next_cache,
Reviewer: Should this branch also return an mtp_outputs?
Author: done
Before submitting
- Add test cases under the tests folder. If there are codecov issues, please add test cases first.

PR types
New features

PR changes
Models

Description