rename num_kv remove multiquery by vince62s · Pull Request #12 · eole-nlp/eole · GitHub

rename num_kv remove multiquery #12


Merged · 3 commits · Jun 9, 2024
31 changes: 12 additions & 19 deletions eole/bin/convert/convert_HF.py
@@ -337,25 +337,19 @@ def run(cls, args):
         mlp_activation_fn = act_table[arch]
         layer_norm = ln_table[arch]
 
-        multiquery = False
-        if "multi_query" in config.keys():
-            multiquery = config["multi_query"]
-            num_kv = 1
+        if "multi_query" in config.keys() and config["multi_query"]:
+            heads_kv = 1  # might be usefull for old config
         elif (
             "num_key_value_heads" in config.keys()
             and config["num_key_value_heads"] != heads
         ):
-            num_kv = config["num_key_value_heads"]
+            heads_kv = config["num_key_value_heads"]
         elif "num_kv_heads" in config.keys() and config["num_kv_heads"] != heads:
-            num_kv = config["num_kv_heads"]
+            heads_kv = config["num_kv_heads"]
         elif "n_head_kv" in config.keys() and config["n_head_kv"] != heads:
-            num_kv = config["n_head_kv"]
+            heads_kv = config["n_head_kv"]
         else:
-            num_kv = 0
-        if num_kv is None:
-            num_kv = 0
-
-        shared_layer = num_kv == 1
+            heads_kv = heads
 
         if "parallel_attn" in config.keys():
            parallel_residual = config["parallel_attn"]
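
Note on the hunk above: the converter no longer tracks a separate multiquery flag plus a num_kv count; it normalizes every Hugging Face spelling of grouped/multi-query attention into one heads_kv value, with heads_kv == heads meaning ordinary multi-head attention. A minimal standalone sketch of the same fallback chain (the infer_heads_kv helper is illustrative, not part of the PR):

    def infer_heads_kv(config: dict, heads: int) -> int:
        if config.get("multi_query"):  # legacy multi-query flag -> 1 KV head
            return 1
        for key in ("num_key_value_heads", "num_kv_heads", "n_head_kv"):
            if key in config and config[key] != heads:
                return config[key]
        return heads  # no GQA/MQA hint: plain multi-head attention

    assert infer_heads_kv({"multi_query": True}, 32) == 1       # MQA (e.g. Falcon 7B)
    assert infer_heads_kv({"num_key_value_heads": 8}, 32) == 8  # GQA (Llama-2-70B style)
    assert infer_heads_kv({}, 32) == 32                         # plain MHA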
@@ -451,9 +445,11 @@ def run(cls, args):
         add_qkvbias = False
         add_ffnbias = False
         rotary_interleave = False
+        shared_layer_norm = False
 
         if arch == "PhiForCausalLM":
             parallel_residual = True
-            shared_layer = True
+            shared_layer_norm = True
             add_qkvbias = True
             add_ffnbias = True
             rotary_interleave = False
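
The old shared_layer flag was doing double duty: derived from num_kv but consumed as a layer-norm switch. This hunk gives the layer-norm role its own shared_layer_norm variable. For intuition, a sketch of the common parallel-residual pattern that such a switch controls (a generic illustration, not eole's actual module code):

    import torch.nn as nn

    class ParallelBlock(nn.Module):
        # One pre-norm feeding both branches (shared) vs. one norm per branch.
        def __init__(self, d_model, attn, ffn, shared_layer_norm=True):
            super().__init__()
            self.attn, self.ffn = attn, ffn
            self.attn_norm = nn.LayerNorm(d_model)
            self.ffn_norm = self.attn_norm if shared_layer_norm else nn.LayerNorm(d_model)

        def forward(self, x):
            # parallel residual: attention and FFN both read the normed input
            return x + self.attn(self.attn_norm(x)) + self.ffn(self.ffn_norm(x))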
Expand Down Expand Up @@ -615,8 +611,6 @@ def get_weight(checkpoint, tensor_name):
+ param,
)

if num_kv == 0:
num_kv = heads
if w is not None:
if type(source) == tuple:
w = eval("w" + srcmap)
@@ -627,7 +621,7 @@ def get_weight(checkpoint, tensor_name):
                     + param
                 ] = w
 
-            if shared_layer:
+            if shared_layer_norm:
                 idx = 0
             else:
                 idx = 1
@@ -857,10 +851,9 @@ def get_weight(checkpoint, tensor_name):
                 rotary_theta=rope_theta,
                 rotary_dim=rotary_dim,
                 sliding_window=sliding_window,
-                multiquery=multiquery,
-                num_kv=num_kv,
+                heads_kv=heads_kv,
                 parallel_residual=parallel_residual,
-                shared_layer_norm=shared_layer,
+                shared_layer_norm=shared_layer_norm,
                 add_qkvbias=add_qkvbias,
                 add_ffnbias=add_ffnbias,
                 num_experts=num_experts,
30 changes: 14 additions & 16 deletions eole/config/models.py
@@ -7,10 +7,10 @@
 
 class EmbeddingsConfig(Config):
     src_word_vec_size: int = Field(
-        default=500, description="Word embedding size for src."
+        default=512, description="Word embedding size for src."
     )
     tgt_word_vec_size: int = Field(
-        default=500, description="Word embedding size for tgt."
+        default=512, description="Word embedding size for tgt."
     )
     word_vec_size: int = Field(
         default=-1, description="Word embedding size for src and tgt."
@@ -40,12 +40,12 @@ class EncoderConfig(Config):
         default="rnn", description="Type of encoder layer(s) to use."
     )
     layers: int = Field(default=2, description="Number of layers in the encoder.")
-    hidden_size: int = Field(default=500, description="Size of encoder hidden states.")
+    hidden_size: int = Field(default=512, description="Size of encoder hidden states.")
 
     # This field should be set at EmbeddingsConfig level but will be copied here for cases
     # where input size to the rnn is different to the hidden size
     src_word_vec_size: int = Field(
-        default=500, description="Word embedding size for src."
+        default=512, description="Word embedding size for src."
     )

@@ -56,12 +56,12 @@ class DecoderConfig(Config):
         default="rnn", description="Type of decoder layer(s) to use."
     )
     layers: int = Field(default=2, description="Number of layers in the decoder.")
-    hidden_size: int = Field(default=500, description="Size of decoder hidden states.")
+    hidden_size: int = Field(default=512, description="Size of decoder hidden states.")
 
     # This field should be set at EmbeddingsConfig level but will be copied here for cases
     # where input size to the rnn is different to the hidden size
     tgt_word_vec_size: int = Field(
-        default=500, description="Word embedding size for tgt."
+        default=512, description="Word embedding size for tgt."
     )
     coverage_attn: bool = Field(
         default=False, description="Train a coverage attention layer."
@@ -197,13 +197,9 @@ class TransformerConfig(Config):
         description="Add bias to nn.Linear of Query/Key/Value in MHA. "
         "Note: this will add bias to output projection layer too.",
     )
-    multiquery: bool = Field(
-        default=False,
-        description="Use MultiQuery attention (https://arxiv.org/pdf/1911.02150.pdf)",
-    )
-    num_kv: int = Field(
-        default=0,
-        description="Number of heads for KV in the variant of MultiQuery attention "
+    heads_kv: int | None = Field(
+        default=None,
+        description="Number of heads for KV. heads_kv=heads if None, else number of heads for KV "
         "(e.g. Falcon 40B)",
     )
     add_ffnbias: bool = Field(
@@ -277,9 +273,11 @@ def _validate_transformer_decoder_config(self):
         #     )
         # )
 
-        # multiquery is mostly a decoder thing, but we should make this cleaner at some point
-        if self.multiquery and self.num_kv == 0:
-            self.num_kv = 1
+        assert (
+            self.hidden_size % self.heads == 0
+        ), "Transformer Model dimension {} must be divisible by the number of heads {}".format(
+            self.hidden_size, self.heads
+        )
         return self
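
With the field renamed, one heads_kv value now expresses standard, grouped, and multi-query attention, and the divisibility check has moved from the attention module into this validator. A hypothetical usage sketch (only the fields touched by this PR are shown; any other required TransformerConfig fields are omitted):

    from eole.config.models import TransformerConfig  # path as in this PR

    mha = TransformerConfig(hidden_size=4096, heads=32)               # heads_kv=None -> heads_kv == heads
    gqa = TransformerConfig(hidden_size=4096, heads=32, heads_kv=8)   # grouped-query: 8 KV heads for 32 query heads
    mqa = TransformerConfig(hidden_size=4096, heads=32, heads_kv=1)   # multi-query: one shared KV head
    # heads=33 with hidden_size=4096 would now fail the assert above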


62 changes: 23 additions & 39 deletions eole/modules/multi_headed_attn.py
@@ -270,49 +270,33 @@ def __init__(
         is_decoder: bool = True,
         attn_type: str = None,
     ) -> None:
-        assert (
-            model_config.hidden_size % model_config.heads == 0
-        ), "Model dimension must be divisible by the number of heads"
         self.dim_per_head = model_config.hidden_size // model_config.heads
         super(MultiHeadedAttention, self).__init__()
         self.heads = model_config.heads
-        self.num_kv = model_config.num_kv
+        self.heads_kv = (
+            model_config.heads_kv
+            if model_config.heads_kv is not None
+            else model_config.heads
+        )
         self.parallel_gpu = running_config.parallel_gpu
 
-        if model_config.num_kv == 0:
-            assert (
-                model_config.hidden_size % self.parallel_gpu == 0
-            ), "Model dimension must be divisible by the number of partitions"
-            self.linear_keys = skip_init(
-                nn.Linear,
-                in_features=model_config.hidden_size,
-                out_features=model_config.hidden_size // self.parallel_gpu,
-                bias=model_config.add_qkvbias,
-            )
-            self.linear_values = skip_init(
-                nn.Linear,
-                in_features=model_config.hidden_size,
-                out_features=model_config.hidden_size // self.parallel_gpu,
-                bias=model_config.add_qkvbias,
-            )
-        else:
-            assert (
-                self.dim_per_head * self.num_kv
-            ) % self.parallel_gpu == 0, (
-                "Model dimension must be divisible by the number of partitions"
-            )
-            self.linear_keys = skip_init(
-                nn.Linear,
-                in_features=model_config.hidden_size,
-                out_features=self.dim_per_head * self.num_kv // self.parallel_gpu,
-                bias=model_config.add_qkvbias,
-            )
-            self.linear_values = skip_init(
-                nn.Linear,
-                in_features=model_config.hidden_size,
-                out_features=self.dim_per_head * self.num_kv // self.parallel_gpu,
-                bias=model_config.add_qkvbias,
-            )
+        assert (
+            self.dim_per_head * self.heads_kv
+        ) % self.parallel_gpu == 0, (
+            "Model dimension must be divisible by the number of partitions"
+        )
+        self.linear_keys = skip_init(
+            nn.Linear,
+            in_features=model_config.hidden_size,
+            out_features=self.dim_per_head * self.heads_kv // self.parallel_gpu,
+            bias=model_config.add_qkvbias,
+        )
+        self.linear_values = skip_init(
+            nn.Linear,
+            in_features=model_config.hidden_size,
+            out_features=self.dim_per_head * self.heads_kv // self.parallel_gpu,
+            bias=model_config.add_qkvbias,
+        )
         self.linear_query = skip_init(
             nn.Linear,
             in_features=model_config.hidden_size,
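
The simplification works because heads_kv == heads reproduces the old num_kv == 0 branch exactly: queries keep heads * dim_per_head output features while keys/values get heads_kv * dim_per_head. Worked numbers (hypothetical sizes, single GPU):

    hidden_size, heads, heads_kv, parallel_gpu = 4096, 32, 8, 1
    dim_per_head = hidden_size // heads                     # 128
    q_features = hidden_size // parallel_gpu                # 4096 -> 32 query heads
    kv_features = dim_per_head * heads_kv // parallel_gpu   # 1024 -> 8 KV heads
    assert (dim_per_head * heads_kv) % parallel_gpu == 0
    # with heads_kv = heads, kv_features == 4096 too, so no special case is needed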
@@ -601,7 +585,7 @@ def forward(
         )
 
         b, h, l, d = key.size()
-        if self.num_kv > 0:
+        if self.heads_kv < self.heads:
             qh = query.size(1)
             # expand key on heads dimension when it's less than query heads (multi-query variant)
             key = key.view(b, -1, 1, l, d).repeat(1, 1, qh // h, 1, 1)
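
For intuition, the expansion above run with toy tensors (assumed shapes: key is (batch, heads_kv, seq_len, dim_per_head); the final .view flattening the group dimension back is implied by the surrounding code, which the diff truncates here):

    import torch

    b, l, d = 1, 4, 8   # batch, sequence length, dim per head
    h, qh = 2, 4        # heads_kv and query heads: qh // h = 2 queries per KV head
    key = torch.randn(b, h, l, d)
    key = key.view(b, -1, 1, l, d).repeat(1, 1, qh // h, 1, 1).view(b, -1, l, d)
    assert key.shape == (b, qh, l, d)  # each KV head duplicated to match the query heads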