From ed1c74d6a0ff41efd0738eb3285765e9fc890cda Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 7 Jan 2025 20:11:10 +0100 Subject: [PATCH 01/11] refactor transformer decoder --- .github/workflows/push.yml | 4 +- eole/config/models.py | 19 +- eole/decoders/__init__.py | 4 +- eole/decoders/cnn_decoder.py | 4 +- eole/decoders/decoder.py | 2 +- eole/decoders/rnn_decoder.py | 6 +- eole/decoders/transformer.py | 356 ++++++++++++++++++++++++ eole/decoders/transformer_base.py | 172 ------------ eole/decoders/transformer_decoder.py | 244 ---------------- eole/decoders/transformer_lm_decoder.py | 173 ------------ eole/models/model.py | 14 +- eole/tests/pull_request_check.sh | 6 +- eole/tests/test_model_lm.yml | 4 +- eole/tests/test_model_lm/config.json | 2 +- 14 files changed, 387 insertions(+), 623 deletions(-) create mode 100644 eole/decoders/transformer.py delete mode 100644 eole/decoders/transformer_base.py delete mode 100644 eole/decoders/transformer_decoder.py delete mode 100644 eole/decoders/transformer_lm_decoder.py diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 1e1d7f7c2..f26c626c8 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -186,7 +186,7 @@ jobs: -config eole/tests/data/lm_data.yaml \ -src_vocab /tmp/eole.vocab.src \ -tgt_vocab /tmp/eole.vocab.src \ - -model '{"hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer_lm", "layers": 2, "heads": 4}}' \ + -model '{"hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer", "layers": 2, "heads": 4}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10}' \ -src_vocab_size 1000 \ -tgt_vocab_size 1000 \ @@ -357,7 +357,7 @@ jobs: -config eole/tests/data/lm_data.yaml \ -src_vocab /tmp/eole.vocab.src \ -tgt_vocab /tmp/eole.vocab.src \ - -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer_lm", "heads": 4}}' \ + -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer", "heads": 4}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10, "model_path": "/tmp/lm.eole.model", "save_checkpoint_steps": 10}' \ -src_vocab_size 1000 \ -tgt_vocab_size 1000 \ diff --git a/eole/config/models.py b/eole/config/models.py index 5db288e4c..6a28f5fbb 100644 --- a/eole/config/models.py +++ b/eole/config/models.py @@ -366,15 +366,6 @@ def _validate_transformer_decoder_config(self): return self -class TransformerLMDecoderConfig(TransformerDecoderConfig): - """ - Right now just wraps TransformerDecoderConfig for simplicity. - Might merge in a single class later once TransformerLM path is clarified. 
- """ - - decoder_type: Literal["transformer_lm"] = Field(default="transformer_lm") - - # use Field with default= + description would be more readable # was inheriting from VocabConfig, but removed for now to facilitate inference tests # could we have different BaseModelConfig classes (inheriting from a base one) @@ -397,7 +388,6 @@ class BaseModelConfig(Config): ) # we shall use discriminators here decoder: Union[ TransformerDecoderConfig, - TransformerLMDecoderConfig, RnnDecoderConfig, CnnDecoderConfig, ] | None = Field( @@ -551,10 +541,7 @@ def update_model_opts(self): update_dict["decoder"] = { "tgt_word_vec_size": self.embeddings.tgt_word_vec_size } - if getattr(self.decoder, "decoder_type", None) in [ - "transformer", - "transformer_lm", - ]: + if getattr(self.decoder, "decoder_type", None) == "transformer": update_dict["decoder"].update( { "position_encoding_type": self.embeddings.position_encoding_type, @@ -737,9 +724,9 @@ def encoder_decoder_type(cls, data: Any) -> Any: if not (isinstance(data, dict)): return data if "decoder" in data.keys(): - data["decoder"]["decoder_type"] = "transformer_lm" + data["decoder"]["decoder_type"] = "transformer" else: - data["decoder"] = {"decoder_type": "transformer_lm"} + data["decoder"] = {"decoder_type": "transformer"} return data @model_validator(mode="before") diff --git a/eole/decoders/__init__.py b/eole/decoders/__init__.py index f7bb79332..4d46fb93d 100644 --- a/eole/decoders/__init__.py +++ b/eole/decoders/__init__.py @@ -1,7 +1,6 @@ """Module defining decoders.""" from eole.decoders.rnn_decoder import InputFeedRNNDecoder, StdRNNDecoder -from eole.decoders.transformer_decoder import TransformerDecoder -from eole.decoders.transformer_lm_decoder import TransformerLMDecoder +from eole.decoders.transformer import TransformerDecoder from eole.decoders.cnn_decoder import CNNDecoder @@ -10,5 +9,4 @@ "ifrnn": InputFeedRNNDecoder, "cnn": CNNDecoder, "transformer": TransformerDecoder, - "transformer_lm": TransformerLMDecoder, } diff --git a/eole/decoders/cnn_decoder.py b/eole/decoders/cnn_decoder.py index aa1862374..decad2a43 100644 --- a/eole/decoders/cnn_decoder.py +++ b/eole/decoders/cnn_decoder.py @@ -22,6 +22,7 @@ def __init__( self, model_config, running_config=None, + with_cross_attn=False, ): super(CNNDecoder, self).__init__() @@ -51,11 +52,12 @@ def __init__( ) @classmethod - def from_config(cls, model_config, running_config=None): + def from_config(cls, model_config, running_config=None, with_cross_attn=False): """Alternate constructor.""" return cls( model_config, running_config, + with_cross_attn=False, ) def init_state(self, **kwargs): diff --git a/eole/decoders/decoder.py b/eole/decoders/decoder.py index 79f9b961f..c4383f7f1 100644 --- a/eole/decoders/decoder.py +++ b/eole/decoders/decoder.py @@ -15,7 +15,7 @@ def __init__(self, attentional=True): self.state = {} @classmethod - def from_config(cls, model_config, running_config=None): + def from_config(cls, model_config, running_config=None, with_cross_attn=False): """Alternate constructor. Subclasses should override this method. 
diff --git a/eole/decoders/rnn_decoder.py b/eole/decoders/rnn_decoder.py index ac0bdcfc2..36abe4fc9 100644 --- a/eole/decoders/rnn_decoder.py +++ b/eole/decoders/rnn_decoder.py @@ -22,6 +22,7 @@ def __init__( self, model_config, running_config=None, + with_cross_attn=False, ): super(RNNDecoderBase, self).__init__( attentional=model_config.global_attention != "none" @@ -67,12 +68,13 @@ def __init__( ) @classmethod - def from_config(cls, model_config, running_config=None): + def from_config(cls, model_config, running_config=None, with_cross_attn=False): """Alternate constructor.""" # config = opt.model.decoder # RnnDecoderConfig return cls( model_config, running_config=running_config, + with_cross_attn=False, ) def init_state(self, **kwargs): @@ -189,6 +191,7 @@ def __init__( self, model_config, running_config=None, + with_cross_attn=False, ): self.hidden_size = model_config.hidden_size self._input_size = model_config.tgt_word_vec_size @@ -268,6 +271,7 @@ def __init__( self, model_config, running_config=None, + with_cross_attn=False, ): self.hidden_size = model_config.hidden_size self._input_size = model_config.tgt_word_vec_size + self.hidden_size diff --git a/eole/decoders/transformer.py b/eole/decoders/transformer.py new file mode 100644 index 000000000..ab3bce8f8 --- /dev/null +++ b/eole/decoders/transformer.py @@ -0,0 +1,356 @@ +""" +Implementation of "Attention is All You Need" and of +subsequent transformer based architectures +""" + +import torch +import torch.nn as nn +from eole.decoders.decoder import DecoderBase +from eole.modules.multi_headed_attn import SelfMHA, ContextMHA +from eole.modules.transformer_mlp import MLP +from eole.modules.moe import MoE +from eole.constants import LayerNorm + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + model_config, + running_config=None, + with_cross_attn=False, + ): + """ + Args: + model_config (eole.config.TransformerDecoderConfig): full decoder config + """ + super(TransformerDecoderLayer, self).__init__() + + self.parallel_residual = model_config.parallel_residual + self.shared_layer_norm = model_config.shared_layer_norm + self.dropout_p = getattr(running_config, "dropout", [0.0])[0] + self.full_context_alignment = model_config.full_context_alignment + self.alignment_heads = model_config.alignment_heads + + self.input_layernorm = LayerNorm[model_config.layer_norm]( + model_config.hidden_size, eps=model_config.norm_eps + ) + self.self_attn = SelfMHA( + model_config, + running_config=running_config, + ) + self.dropout = nn.Dropout(self.dropout_p) + + if with_cross_attn: + self.precontext_layernorm = LayerNorm[model_config.layer_norm]( + model_config.hidden_size, eps=model_config.norm_eps + ) + self.context_attn = ContextMHA( + model_config, + running_config=running_config, + ) + else: + self.context_attn = None + + if model_config.parallel_residual and not model_config.shared_layer_norm: + self.residual_layernorm = LayerNorm[model_config.layer_norm]( + model_config.hidden_size, eps=model_config.norm_eps + ) + + self.post_attention_layernorm = LayerNorm[model_config.layer_norm]( + model_config.hidden_size, eps=model_config.norm_eps + ) + + if model_config.num_experts > 0: + self.mlp = MoE(model_config, running_config) + else: + self.mlp = MLP( + model_config, + running_config=running_config, + ) + + def forward(self, layer_in, **kwargs): + """Extend `_forward` for (possibly) multiple decoder pass: + Always a default (future masked) decoder forward pass, + Possibly a second future aware decoder pass for joint learn + 
full context alignement, :cite:`garg2019jointly`. + + Args: + + Returns: + (FloatTensor, FloatTensor, FloatTensor or None): + + * layer_out ``(batch_size, T, model_dim)`` + * attns ``(batch_size, heads, T, src_len)`` + """ + enc_out = kwargs.pop("enc_out", None) + src_pad_mask = kwargs.pop("src_pad_mask", None) + attn_mask = kwargs.pop("attn_mask", None) + step = kwargs.pop("step", None) + return_attn = kwargs.pop("return_attn", False) + position_embeddings = kwargs.pop("position_embeddings", None) + + norm_layer_in = self.input_layernorm(layer_in) + + self_attn, attns = self.self_attn( + norm_layer_in, + attn_mask=attn_mask, + step=step, + return_attn=return_attn, + position_embeddings=position_embeddings, + ) + + if self.dropout_p > 0: + self_attn = self.dropout(self_attn) + + if self.parallel_residual: + if self.context_attn: + ctx_attn, attns = self.context_attn( + enc_out, + enc_out, + norm_layer_in, + attn_mask=~src_pad_mask, + return_attn=return_attn, + ) + else: + ctx_attn = 0 + if not self.shared_layer_norm: + norm_res_layer_in = self.residual_layernorm(layer_in) + ff_in = norm_res_layer_in + else: + ff_in = norm_layer_in + else: + if self.context_attn: + norm_query = self.precontext_layernorm(self_attn + layer_in) + ctx_attn, attns = self.context_attn( + enc_out, + enc_out, + norm_query, + attn_mask=~src_pad_mask, + return_attn=return_attn, + ) + if self.dropout_p > 0: + ctx_attn = self.dropout(ctx_attn) + else: + ctx_attn = 0 + ff_in = self.post_attention_layernorm(ctx_attn + self_attn + layer_in) + # we apply residual with un-normed + layer_out = self.mlp(ff_in) + layer_in + self_attn + ctx_attn + + return layer_out, attns + + def get_attn_align(self, layer_in, **kwargs): + attns = kwargs.pop("attns", None) + if self.full_context_alignment: + # return _, (B, Q_len, K_len) + _, attns = self.forward(layer_in, **kwargs) + if self.alignment_heads > 0: + attns = attns[:, : self.alignment_heads, :, :].contiguous() + # layer average attention across heads, get ``(B, Q, K)`` + # Case 1: no full_context, no align heads -> layer avg baseline + # Case 2: no full_context, 1 align heads -> guided align + # Case 3: full_context, 1 align heads -> full cte guided align + attn_align = attns.mean(dim=1) + + return attn_align + + def update_dropout(self, dropout, attention_dropout): + self.self_attn.update_dropout(attention_dropout) + if self.context_attn: + self.context_attn.update_dropout(attention_dropout) + self.mlp.update_dropout(dropout) + self.dropout.p = dropout + + +class TransformerDecoder(DecoderBase): + def __init__( + self, + model_config, + running_config=None, + with_cross_attn=False, + ): + super(TransformerDecoder, self).__init__() + self.transformer_layers = nn.ModuleList( + [ + TransformerDecoderLayer( + model_config, + running_config=running_config, + with_cross_attn=with_cross_attn, + ) + for i in range(model_config.layers) + ] + ) + # This is the Decoder out layer norm + self.layer_norm = LayerNorm[model_config.layer_norm]( + model_config.hidden_size, eps=model_config.norm_eps + ) + self._disable_cache() + self.alignment_layer = model_config.alignment_layer + self.with_cross_attn = with_cross_attn + self.sliding_window = model_config.sliding_window + + @classmethod + def from_config(cls, model_config, running_config=None, with_cross_attn=False): + """Alternate constructor.""" + return cls( + model_config, + running_config=running_config, + with_cross_attn=with_cross_attn, + ) + + def init_state(self, **kwargs): + """Initialize decoder state.""" + pass + + def map_state(self, 
fn): + # if self.state["src"] is not None: + # self.state["src"] = fn(self.state["src"], 0) + for layer in self.transformer_layers: + if self.with_cross_attn: + if layer.context_attn.layer_cache[1]["keys"].numel() != 0: + x = fn(layer.context_attn.layer_cache[1]["keys"], 0) + y = fn(layer.context_attn.layer_cache[1]["values"], 0) + layer.context_attn.layer_cache = True, {"keys": x, "values": y} + if layer.self_attn.layer_cache[1]["keys"].numel() != 0: + x = fn(layer.self_attn.layer_cache[1]["keys"], 0) + y = fn(layer.self_attn.layer_cache[1]["values"], 0) + if layer.self_attn.layer_cache[1].get("key_pad_mask", None) is not None: + z = fn(layer.self_attn.layer_cache[1]["key_pad_mask"], 0) + else: + z = None + layer.self_attn.layer_cache = True, { + "keys": x, + "values": y, + "key_pad_mask": z, + } + + def detach_state(self): + raise NotImplementedError + + def update_dropout(self, dropout, attention_dropout): + for layer in self.transformer_layers: + layer.update_dropout(dropout, attention_dropout) + + def _compute_attn_mask(self, tgt_pad_mask): + tgt_len = tgt_pad_mask.size(-1) + # Add triangular future_mask and pad_mask, result mask in (B, T, T). + future_mask = torch.tril( + torch.ones( + (tgt_len, tgt_len), device=tgt_pad_mask.device, dtype=torch.bool + ), + diagonal=0, + ) + if self.sliding_window > 0: + future_mask = future_mask.triu_(-self.sliding_window) + attn_mask = ~tgt_pad_mask & future_mask.unsqueeze(0) + return attn_mask + + def forward(self, emb, **kwargs): + """Decode, possibly stepwise.""" + + assert emb.dim() == 3 # batch x len x embedding_dim + enc_out = kwargs.pop("enc_out", None) + + tgt_pad_mask = kwargs.pop("tgt_pad_mask", None) + assert tgt_pad_mask is not None, "TransformerDecoder requires a tgt pad mask" + src_pad_mask = kwargs.pop("src_pad_mask", None) + if self.with_cross_attn: + assert ( + src_pad_mask is not None + ), "TransformerDecoder requires a src pad mask" + step = kwargs.pop("step", None) + with_align = kwargs.pop("with_align", False) + return_attn = with_align or kwargs.pop("return_attn", False) + position_embeddings = kwargs.pop("position_embeddings", None) + attn_aligns = [] + + if emb.size(1) > 1: + # masking is necessary when sequence length is greater than one + attn_mask = self._compute_attn_mask(tgt_pad_mask) + attn_mask = attn_mask.unsqueeze(1) + attn_mask = attn_mask.expand(-1, -1, attn_mask.size(3), -1) + if self.with_cross_attn: + src_pad_mask = src_pad_mask.unsqueeze(1).expand( + -1, -1, attn_mask.size(3), -1 + ) + # mask now are (batch x 1 x tlen x s or t len) + # 1 = heads to be expanded in MHA + else: + attn_mask = None + if self.with_cross_attn: + src_pad_mask = src_pad_mask.unsqueeze(1) + + if step == 0: + self._enable_cache(emb.device, tgt_pad_mask) + + for layer in self.transformer_layers: + emb, attn = layer( + emb, + enc_out=enc_out if enc_out is not None else emb, + src_pad_mask=src_pad_mask, + attn_mask=attn_mask, + step=step, + return_attn=return_attn, + position_embeddings=position_embeddings, + ) + if with_align: + attn_align = layer.get_attn_align( + emb, + enc_out=enc_out if enc_out is not None else emb, + src_pad_mask=src_pad_mask, + attn_mask=~tgt_pad_mask, + step=step, + return_attn=return_attn, + position_embeddings=position_embeddings, + attns=attn, + ) + if attn_align is not None: + attn_aligns.append(attn_align) + + emb = self.layer_norm(emb) + + # we take the first head + top_attn = None if attn is None else attn[:, 0, :, :].contiguous() + attns = {"std": top_attn} + if with_align: + attns["align"] = 
attn_aligns[self.alignment_layer] # `(B, Q, K)` + # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg + + # TODO change the way attns is returned dict => list or tuple (onnx) + return emb, attns + + def _enable_cache(self, device, pad_mask): + for layer in self.transformer_layers: + # first value set to True triggered by the beginning of decoding + # layer_cache becomes active in the MultiHeadedAttention fwd + if layer.context_attn: + layer.context_attn.layer_cache = ( + True, + { + "keys": torch.tensor([], device=device), + "values": torch.tensor([], device=device), + }, + ) + layer.self_attn.layer_cache = ( + True, + { + "keys": torch.tensor([], device=device), + "values": torch.tensor([], device=device), + "key_pad_mask": pad_mask, + }, + ) + + def _disable_cache(self): + for layer in self.transformer_layers: + layer.self_attn.layer_cache = ( + False, + { + "keys": torch.tensor([]), + "values": torch.tensor([]), + "key_pad_mask": None, + }, + ) + if layer.context_attn: + layer.context_attn.layer_cache = ( + False, + {"keys": torch.tensor([]), "values": torch.tensor([])}, + ) diff --git a/eole/decoders/transformer_base.py b/eole/decoders/transformer_base.py deleted file mode 100644 index 03bebc529..000000000 --- a/eole/decoders/transformer_base.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -Implementation of "Attention is All You Need" and of -subsequent transformer based architectures -""" - -import torch -import torch.nn as nn -from eole.decoders.decoder import DecoderBase -from eole.modules.multi_headed_attn import SelfMHA -from eole.modules.transformer_mlp import MLP -from eole.modules.moe import MoE -from eole.constants import LayerNorm - - -class TransformerDecoderLayerBase(nn.Module): - def __init__( - self, - model_config, - running_config=None, - ): - """ - Args: - model_config (eole.config.TransformerDecoderConfig): full decoder config - """ - super(TransformerDecoderLayerBase, self).__init__() - - self.parallel_residual = model_config.parallel_residual - self.shared_layer_norm = model_config.shared_layer_norm - self.dropout_p = getattr(running_config, "dropout", [0.0])[0] - self.full_context_alignment = model_config.full_context_alignment - self.alignment_heads = model_config.alignment_heads - self.sliding_window = model_config.sliding_window - - self.input_layernorm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) - self.self_attn = SelfMHA( - model_config, - running_config=running_config, - ) - self.dropout = nn.Dropout(self.dropout_p) - self.post_attention_layernorm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) - if model_config.parallel_residual and not model_config.shared_layer_norm: - self.residual_layernorm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) - if model_config.num_experts > 0: - self.mlp = MoE(model_config, running_config) - else: - self.mlp = MLP( - model_config, - running_config=running_config, - ) - - def forward(self, *args, **kwargs): - """Extend `_forward` for (possibly) multiple decoder pass: - Always a default (future masked) decoder forward pass, - Possibly a second future aware decoder pass for joint learn - full context alignement, :cite:`garg2019jointly`. 
- - Args: - * All arguments of _forward, of which - with_align (bool): needed to compute attn_align - return_attn (bool): to force MHA to return attns - - Returns: - (FloatTensor, FloatTensor, FloatTensor or None): - - * layer_out ``(batch_size, T, model_dim)`` - * top_attn ``(batch_size, T, src_len)`` - * attn_align ``(batch_size, T, src_len)`` or None - """ - with_align = kwargs.pop("with_align", False) - layer_out, attns = self._forward(*args, **kwargs) - top_attn = None if attns is None else attns[:, 0, :, :].contiguous() - attn_align = None - if with_align: - if self.full_context_alignment: - # return _, (B, Q_len, K_len) - _, attns = self._forward(*args, **kwargs, future=True) - - if self.alignment_heads > 0: - attns = attns[:, : self.alignment_heads, :, :].contiguous() - # layer average attention across heads, get ``(B, Q, K)`` - # Case 1: no full_context, no align heads -> layer avg baseline - # Case 2: no full_context, 1 align heads -> guided align - # Case 3: full_context, 1 align heads -> full cte guided align - attn_align = attns.mean(dim=1) - return layer_out, top_attn, attn_align - - def update_dropout(self, dropout, attention_dropout): - self.self_attn.update_dropout(attention_dropout) - self.mlp.update_dropout(dropout) - self.dropout.p = dropout - - def _forward(self, *args, **kwargs): - raise NotImplementedError - - def _compute_attn_mask(self, tgt_pad_mask, future): - tgt_len = tgt_pad_mask.size(-1) - if not future: - # Add triangular future_mask and pad_mask, result mask in (B, T, T). - future_mask = torch.tril( - torch.ones( - (tgt_len, tgt_len), device=tgt_pad_mask.device, dtype=torch.bool - ), - diagonal=0, - ) - if self.sliding_window > 0: - future_mask = future_mask.triu_(-self.sliding_window) - attn_mask = ~tgt_pad_mask & future_mask.unsqueeze(0) - else: - # Only mask padding, result mask in (B, 1, T). 
- attn_mask = ~tgt_pad_mask - return attn_mask - - -class TransformerDecoderBase(DecoderBase): - def __init__( - self, - model_config, - running_config=None, - ): - super(TransformerDecoderBase, self).__init__() - - self.alignment_layer = model_config.alignment_layer - - @classmethod - def from_config(cls, model_config, running_config=None): - """Alternate constructor.""" - return cls( - model_config, - running_config=running_config, - ) - - def init_state(self, **kwargs): - """Initialize decoder state.""" - pass - - def map_state(self, fn): - # if self.state["src"] is not None: - # self.state["src"] = fn(self.state["src"], 0) - for layer in self.transformer_layers: - if hasattr(layer, "context_attn"): - if layer.context_attn.layer_cache[1]["keys"].numel() != 0: - x = fn(layer.context_attn.layer_cache[1]["keys"], 0) - y = fn(layer.context_attn.layer_cache[1]["values"], 0) - layer.context_attn.layer_cache = True, {"keys": x, "values": y} - if layer.self_attn.layer_cache[1]["keys"].numel() != 0: - x = fn(layer.self_attn.layer_cache[1]["keys"], 0) - y = fn(layer.self_attn.layer_cache[1]["values"], 0) - if layer.self_attn.layer_cache[1].get("key_pad_mask", None) is not None: - z = fn(layer.self_attn.layer_cache[1]["key_pad_mask"], 0) - else: - z = None - layer.self_attn.layer_cache = True, { - "keys": x, - "values": y, - "key_pad_mask": z, - } - - def detach_state(self): - raise NotImplementedError - - def forward(self, *args, **kwargs): - raise NotImplementedError - - def update_dropout(self, dropout, attention_dropout): - for layer in self.transformer_layers: - layer.update_dropout(dropout, attention_dropout) diff --git a/eole/decoders/transformer_decoder.py b/eole/decoders/transformer_decoder.py deleted file mode 100644 index 42f8fddfb..000000000 --- a/eole/decoders/transformer_decoder.py +++ /dev/null @@ -1,244 +0,0 @@ -""" -Implementation of "Attention is All You Need" and of -subsequent transformer based architectures -""" - -import torch -import torch.nn as nn -from eole.decoders.transformer_base import ( - TransformerDecoderLayerBase, - TransformerDecoderBase, -) -from eole.modules.multi_headed_attn import ContextMHA -from eole.constants import LayerNorm - - -class TransformerDecoderLayer(TransformerDecoderLayerBase): - """Transformer Decoder layer block in Pre-Norm style. - Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style, - providing better converge speed and performance. This is also the actual - implementation in tensor2tensor and also avalable in fairseq. - See https://tunz.kr/post/4 and :cite:`DeeperTransformer`. - - """ - - def __init__( - self, - model_config, - running_config=None, - ): - """ - Args: - See TransformerDecoderLayerBase - """ - super(TransformerDecoderLayer, self).__init__( - model_config, - running_config=running_config, - ) - - self.precontext_layernorm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) - self.context_attn = ContextMHA( - model_config, - running_config=running_config, - ) - - def update_dropout(self, dropout, attention_dropout): - super(TransformerDecoderLayer, self).update_dropout(dropout, attention_dropout) - self.context_attn.update_dropout(attention_dropout) - - def _forward( - self, - layer_in, - enc_out, - src_pad_mask, - tgt_pad_mask, - step=None, - future=False, - return_attn=False, - position_embeddings=None, - ): - """A naive forward pass for transformer decoder. 
- - # T: could be 1 in the case of stepwise decoding or tgt_len - - Args: - layer_in (FloatTensor): ``(batch_size, T, model_dim)`` - enc_out (FloatTensor): ``(batch_size, src_len, model_dim)`` - src_pad_mask (bool): ``(batch_size, 1, src_len)`` - tgt_pad_mask (bool): ``(batch_size, 1, T)`` - step (int or None): stepwise decoding counter - future (bool): If set True, do not apply future_mask. - return_attn (bool) : if set True requires attns output - position_embeddings (FloatTensor): rotary position encodings, if any - - Returns: - (FloatTensor, FloatTensor): - - * layer_out ``(batch_size, T, model_dim)`` - * attns ``(batch_size, head, T, src_len)`` - - """ - attn_mask = None - src_pad_mask = src_pad_mask.unsqueeze(1) # [B,1,1,slen] - - if layer_in.size(1) > 1: - # masking is necessary when sequence length is greater than one - attn_mask = self._compute_attn_mask(tgt_pad_mask, future) - attn_mask = attn_mask.unsqueeze(1) - attn_mask = attn_mask.expand(-1, -1, attn_mask.size(3), -1) - src_pad_mask = src_pad_mask.expand(-1, -1, attn_mask.size(3), -1) - # mask now are (batch x 1 x tlen x s or t len) - # 1 = heads to be expanded in MHA - - norm_layer_in = self.input_layernorm(layer_in) - - self_attn, _ = self.self_attn( - norm_layer_in, - attn_mask=attn_mask, - step=step, - return_attn=return_attn, - position_embeddings=position_embeddings, - ) - - if self.dropout_p > 0: - self_attn = self.dropout(self_attn) - - if self.parallel_residual: - ctx_attn, attns = self.context_attn( - enc_out, - enc_out, - norm_layer_in, - attn_mask=~src_pad_mask, - return_attn=return_attn, - ) - if not self.shared_layer_norm: - norm_res_layer_in = self.residual_layernorm(layer_in) - ff_in = norm_res_layer_in - else: - ff_in = norm_layer_in - else: - norm_query = self.precontext_layernorm(self_attn + layer_in) - ctx_attn, attns = self.context_attn( - enc_out, - enc_out, - norm_query, - attn_mask=~src_pad_mask, - return_attn=return_attn, - ) - if self.dropout_p > 0: - ctx_attn = self.dropout(ctx_attn) - ff_in = self.post_attention_layernorm(ctx_attn + self_attn + layer_in) - # we apply residual with un-normed - layer_out = self.mlp(ff_in) + layer_in + self_attn + ctx_attn - return layer_out, attns - - -class TransformerDecoder(TransformerDecoderBase): - """The Transformer decoder from "Attention is All You Need". 
- :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` - - Args: - model_config (eole.config.TransformerDecoderConfig): full decoder config - embeddings (eole.modules.Embeddings): - embeddings to use, should have positional encodings - running_config (TrainingConfig / InferenceConfig) - """ - - def __init__( - self, - model_config, - running_config=None, - ): - super(TransformerDecoder, self).__init__( - model_config, running_config=running_config - ) - - self.transformer_layers = nn.ModuleList( - [ - TransformerDecoderLayer( - model_config, - running_config=running_config, - ) - for i in range(model_config.layers) - ] - ) - # This is the Decoder out layer norm - self.layer_norm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) - self._disable_cache() - - def forward(self, emb, **kwargs): - """Decode, possibly stepwise.""" - - assert emb.dim() == 3 # batch x len x embedding_dim - enc_out = kwargs.pop("enc_out", None) - - tgt_pad_mask = kwargs.pop("tgt_pad_mask", None) - assert tgt_pad_mask is not None, "TransformerDecoder requires a tgt pad mask" - src_pad_mask = kwargs.pop("src_pad_mask", None) - assert src_pad_mask is not None, "TransformerDecoder requires a src pad mask" - step = kwargs.pop("step", None) - with_align = kwargs.pop("with_align", False) - return_attn = with_align or kwargs.pop("return_attn", False) - position_embeddings = kwargs.pop("position_embeddings", None) - attn_aligns = [] - - if step == 0: - self._enable_cache(enc_out.device) - - for layer in self.transformer_layers: - emb, attn, attn_align = layer( - emb, - enc_out if enc_out is not None else emb, - src_pad_mask, - tgt_pad_mask, - step=step, - with_align=with_align, - return_attn=return_attn, - position_embeddings=position_embeddings, - ) - if attn_align is not None: - attn_aligns.append(attn_align) - - emb = self.layer_norm(emb) - - attns = {"std": attn} - if with_align: - attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)` - # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg - - # TODO change the way attns is returned dict => list or tuple (onnx) - return emb, attns - - def _enable_cache(self, device): - for layer in self.transformer_layers: - # first value set to True triggered by the beginning of decoding - # layer_cache becomes active in the MultiHeadedAttention fwd - layer.context_attn.layer_cache = ( - True, - { - "keys": torch.tensor([], device=device), - "values": torch.tensor([], device=device), - }, - ) - layer.self_attn.layer_cache = ( - True, - { - "keys": torch.tensor([], device=device), - "values": torch.tensor([], device=device), - }, - ) - - def _disable_cache(self): - for layer in self.transformer_layers: - layer.self_attn.layer_cache = ( - False, - {"keys": torch.tensor([]), "values": torch.tensor([])}, - ) - layer.context_attn.layer_cache = ( - False, - {"keys": torch.tensor([]), "values": torch.tensor([])}, - ) diff --git a/eole/decoders/transformer_lm_decoder.py b/eole/decoders/transformer_lm_decoder.py deleted file mode 100644 index 2c72525be..000000000 --- a/eole/decoders/transformer_lm_decoder.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -Implementation of "Attention is All You Need" and of -subsequent transformer based architectures -""" - -import torch -import torch.nn as nn - -from eole.decoders.transformer_base import ( - TransformerDecoderLayerBase, - TransformerDecoderBase, -) -from eole.constants import LayerNorm - - -class TransformerLMDecoderLayer(TransformerDecoderLayerBase): - """Transformer Decoder only layer block in 
GPT style. - Args: - See TransformerDecoderLayerBase - """ - - def _forward( - self, - layer_in, - pad_mask, - step=None, - future=False, - return_attn=False, - position_embeddings=None, - ): - """A naive forward pass for transformer decoder. - - # T: could be 1 in the case of stepwise decoding or tgt_len - - Args: - layer_in (FloatTensor): ``(batch_size, T, model_dim)`` - pad_mask (bool): ``(batch_size, 1, T)`` - layer_cache (dict or None): cached layer info when stepwise decode - step (int or None): stepwise decoding counter - future (bool): If set True, do not apply future_mask. - return_attn (bool): If set True return attn - position_embeddings (FloatTensor): rotary position encodings, if any - - Returns: - (FloatTensor, FloatTensor): - - * layer_out ``(batch_size, T, model_dim)`` - * attns ``(batch_size, head, T, T)`` - - """ - attn_mask = None - - if layer_in.size(1) > 1: - # Masking is necessary when sequence length is greater than one - # The decoding has not started yet, - # we compute the scores on the source tokens in one shot. - attn_mask = self._compute_attn_mask(pad_mask, future) - attn_mask = attn_mask.unsqueeze(1) - attn_mask = attn_mask.expand(-1, -1, attn_mask.size(3), -1) - # mask now are (batch x 1 x tlen x tlen) - # 1 = heads to be expanded in MHA - - norm_layer_in = self.input_layernorm(layer_in) - - attn_output, attns = self.self_attn( - norm_layer_in, - attn_mask=attn_mask, - step=step, - return_attn=return_attn, - position_embeddings=position_embeddings, - ) - if self.dropout_p > 0: - attn_output = self.dropout(attn_output) - if self.parallel_residual: - # we apply residual with un-normed - if not self.shared_layer_norm: - norm_res_layer_in = self.residual_layernorm(layer_in) - ff_in = norm_res_layer_in - else: - ff_in = norm_layer_in - else: - ff_in = self.post_attention_layernorm(attn_output + layer_in) - # add residual to mlp output - layer_out = self.mlp(ff_in) + layer_in + attn_output - - return layer_out, attns - - -class TransformerLMDecoder(TransformerDecoderBase): - """The Transformer decoder from GPT-2 - Args: - model_config (eole.config.TransformerDecoderConfig): full decoder config - running_config (TrainingConfig / InferenceConfig) - """ - - def __init__( - self, - model_config, - running_config=None, - ): - super(TransformerLMDecoder, self).__init__(model_config) - - self.transformer_layers = nn.ModuleList( - [ - TransformerLMDecoderLayer( - model_config, - running_config=running_config, - ) - for i in range(model_config.layers) - ] - ) - # This is the Decoder out layer norm - self.layer_norm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) - self._disable_cache() - - def forward(self, emb, **kwargs): - """Decode, possibly stepwise.""" - - assert emb.dim() == 3 # batch x len x embedding_dim - pad_mask = kwargs.pop("tgt_pad_mask", None) - assert pad_mask is not None, "TransformerLMDecoder requires a pad mask" - step = kwargs.pop("step", None) - with_align = kwargs.pop("with_align", False) - return_attn = kwargs.pop("return_attn", False) - return_attn = with_align or return_attn - assert not with_align, "TransformerLMDecoder does not support align" - position_embeddings = kwargs.pop("position_embeddings", None) - - if step == 0: - self._enable_cache(emb.device, pad_mask) - - for layer in self.transformer_layers: - emb, attn, _ = layer( - emb, - pad_mask, - step=step, - with_align=with_align, - return_attn=return_attn, - position_embeddings=position_embeddings, - ) - - emb = self.layer_norm(emb) - - attns = 
{"std": attn} - - # TODO change the way attns is returned dict => list or tuple (onnx) - return emb, attns - - def _enable_cache(self, device, pad_mask): - for layer in self.transformer_layers: - if hasattr(layer, "self_attn"): - layer.self_attn.layer_cache = ( - True, - { - "keys": torch.tensor([], device=device), - "values": torch.tensor([], device=device), - "key_pad_mask": pad_mask, - }, - ) - - def _disable_cache(self): - for layer in self.transformer_layers: - if hasattr(layer, "self_attn"): - layer.self_attn.layer_cache = ( - False, - { - "keys": torch.tensor([]), - "values": torch.tensor([]), - "key_pad_mask": None, - }, - ) diff --git a/eole/models/model.py b/eole/models/model.py index 5e2dfc6f7..faf3e5a28 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -53,7 +53,7 @@ def build_encoder(model_config, running_config=None): ) # full config for now -def build_decoder(model_config, running_config=None): +def build_decoder(model_config, running_config=None, with_cross_attn=False): """ Various decoder dispatcher function. Args: @@ -65,7 +65,9 @@ def build_decoder(model_config, running_config=None): else model_config.decoder.decoder_type ) return str2dec[dec_type].from_config( - model_config.decoder, running_config=running_config + model_config.decoder, + running_config=running_config, + with_cross_attn=with_cross_attn, ) @@ -785,7 +787,9 @@ def build_blocks(cls, model_config, vocabs, running_config=None): share_embeddings=model_config.share_embeddings, src_emb=src_emb, ) - decoder = build_decoder(model_config, running_config=running_config) + decoder = build_decoder( + model_config, running_config=running_config, with_cross_attn=True + ) return cls( encoder=encoder, decoder=decoder, @@ -865,7 +869,9 @@ def __init__(self, **kwargs): @classmethod def build_blocks(cls, model_config, vocabs, running_config=None): tgt_emb = build_tgt_emb(model_config, vocabs, running_config=running_config) - decoder = build_decoder(model_config, running_config=running_config) + decoder = build_decoder( + model_config, running_config=running_config, with_cross_attn=False + ) return cls( decoder=decoder, tgt_emb=tgt_emb, diff --git a/eole/tests/pull_request_check.sh b/eole/tests/pull_request_check.sh index 057992d0c..0d57ac116 100755 --- a/eole/tests/pull_request_check.sh +++ b/eole/tests/pull_request_check.sh @@ -258,7 +258,7 @@ ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/lm_data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ -share_vocab \ - -model '{"hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer_lm", "layers": 2, "heads": 4}}' \ + -model '{"hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer", "layers": 2, "heads": 4}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10}' \ -report_every 5 \ -src_vocab_size 1000 \ @@ -490,7 +490,7 @@ ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/lm_data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ -tgt_vocab $TMP_OUT_DIR/eole.vocab.src \ - -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer_lm", "heads": 4}}' \ + -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer", "heads": 4}}' \ -training '{"batch_size": 10, 
"num_workers": 0, "bucket_size": 1024, "train_steps": 10, "model_path": "'"$TMP_OUT_DIR"'/lm.eole.model", "save_checkpoint_steps": 10}' \ -report_every 5 \ -src_vocab_size 1000 \ @@ -501,7 +501,7 @@ ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/lm_data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ -tgt_vocab $TMP_OUT_DIR/eole.vocab.src \ - -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer_lm", "heads": 4}}' \ + -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer", "heads": 4}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 20, "train_from": "'"$TMP_OUT_DIR"'/lm.eole.model/step_10", "save_checkpoint_steps": 10, "update_vocab": True, "reset_optim": "states"}' \ -report_every 5 \ -src_vocab_size 1000 \ diff --git a/eole/tests/test_model_lm.yml b/eole/tests/test_model_lm.yml index c992677a5..9c3739ace 100644 --- a/eole/tests/test_model_lm.yml +++ b/eole/tests/test_model_lm.yml @@ -113,8 +113,8 @@ share_vocab: true # feat_vec_exponent=0.7, # model_type='text', # model_dtype='fp32', -# encoder_type='transformer_lm', -# decoder_type='transformer_lm', +# encoder_type='transformer', +# decoder_type='transformer', # layers=-1, # enc_layers=2, # dec_layers=2, diff --git a/eole/tests/test_model_lm/config.json b/eole/tests/test_model_lm/config.json index 5b8fd7d8c..71faf629e 100644 --- a/eole/tests/test_model_lm/config.json +++ b/eole/tests/test_model_lm/config.json @@ -2,7 +2,7 @@ "model": { "decoder": { "hidden_size": 64, - "decoder_type": "transformer_lm", + "decoder_type": "transformer", "add_ffnbias": true, "transformer_ff": 256, "tgt_word_vec_size": 64, From 2d709a40e22153e3a8745815637ce3cc92545bec Mon Sep 17 00:00:00 2001 From: vince62s Date: Tue, 7 Jan 2025 21:55:38 +0100 Subject: [PATCH 02/11] update docs --- eole/decoders/transformer.py | 50 ++++++++++++++++++++++++------------ eole/encoders/transformer.py | 28 ++++++++++++-------- 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/eole/decoders/transformer.py b/eole/decoders/transformer.py index ab3bce8f8..2fe9fab1f 100644 --- a/eole/decoders/transformer.py +++ b/eole/decoders/transformer.py @@ -1,6 +1,5 @@ """ -Implementation of "Attention is All You Need" and of -subsequent transformer based architectures +Implementation of "Attention is All You Need" Transformer Decoder """ import torch @@ -13,24 +12,29 @@ class TransformerDecoderLayer(nn.Module): + """The Transformer encoder from "Attention is All You Need" + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` + + Args: + model_config (eole.config.TransformerEncoderConfig): full encoder config + running_config (TrainingConfig / InferenceConfig) + with_cross_attn (True when used with an encoder) + """ + def __init__( self, model_config, running_config=None, with_cross_attn=False, ): - """ - Args: - model_config (eole.config.TransformerDecoderConfig): full decoder config - """ super(TransformerDecoderLayer, self).__init__() - self.parallel_residual = model_config.parallel_residual self.shared_layer_norm = model_config.shared_layer_norm self.dropout_p = getattr(running_config, "dropout", [0.0])[0] self.full_context_alignment = model_config.full_context_alignment self.alignment_heads = model_config.alignment_heads + # order of layers corresponds to forward flow of tensors self.input_layernorm = 
LayerNorm[model_config.layer_norm]( model_config.hidden_size, eps=model_config.norm_eps ) @@ -69,19 +73,23 @@ def __init__( ) def forward(self, layer_in, **kwargs): - """Extend `_forward` for (possibly) multiple decoder pass: - Always a default (future masked) decoder forward pass, - Possibly a second future aware decoder pass for joint learn - full context alignement, :cite:`garg2019jointly`. - + """ Args: + layer_in (FloatTensor): ``(batch_size, tgt_len, model_dim)`` + **kwargs + enc_out: (only when using encoder/decoder) + src_pad_mask: (only when using encoder/decoder) + attn_mask (BoolTensor): ``(batch_size, 1, src/tgt_len, tgt_len)`` + step: int + return_attention: bool + position_embeddings (FloatTensor): rotary position encodings, if any Returns: - (FloatTensor, FloatTensor, FloatTensor or None): - - * layer_out ``(batch_size, T, model_dim)`` - * attns ``(batch_size, heads, T, src_len)`` + (FloatTensor, FloatTensor): + * layer_out ``(batch_size, tgt_len, model_dim)`` + * attns """ + enc_out = kwargs.pop("enc_out", None) src_pad_mask = kwargs.pop("src_pad_mask", None) attn_mask = kwargs.pop("attn_mask", None) @@ -139,6 +147,7 @@ def forward(self, layer_in, **kwargs): return layer_out, attns def get_attn_align(self, layer_in, **kwargs): + """:cite:`garg2019jointly`.""" attns = kwargs.pop("attns", None) if self.full_context_alignment: # return _, (B, Q_len, K_len) @@ -162,6 +171,15 @@ def update_dropout(self, dropout, attention_dropout): class TransformerDecoder(DecoderBase): + """The Transformer encoder from "Attention is All You Need" + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` + + Args: + model_config (eole.config.TransformerEncoderConfig): full encoder config + running_config (TrainingConfig / InferenceConfig) + with_cross_attn (True when used with an encoder) + """ + def __init__( self, model_config, diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py index 95352557f..85128a938 100644 --- a/eole/encoders/transformer.py +++ b/eole/encoders/transformer.py @@ -1,5 +1,5 @@ """ -Implementation of "Attention is All You Need" +Implementation of "Attention is All You Need" Transformer Encoder """ import torch.nn as nn @@ -25,9 +25,10 @@ def __init__( running_config=None, ): super(TransformerEncoderLayer, self).__init__() - self.parallel_residual = model_config.parallel_residual self.dropout_p = getattr(running_config, "dropout", [0.0])[0] + + # order of layers corresponds to forward flow of tensors self.input_layernorm = LayerNorm[model_config.layer_norm]( model_config.hidden_size, eps=model_config.norm_eps ) @@ -83,15 +84,7 @@ class TransformerEncoder(EncoderBase): Args: model_config (eole.config.TransformerEncoderConfig): full encoder config - embeddings (eole.modules.Embeddings): - embeddings to use, should have positional encodings running_config (TrainingConfig / InferenceConfig) - Returns: - (torch.FloatTensor, torch.FloatTensor): - - * enc_out ``(batch_size, src_len, model_dim)`` - * encoder final state: None in the case of Transformer - * src_len ``(batch_size)`` """ def __init__( @@ -124,7 +117,20 @@ def from_config(cls, model_config, running_config=None): ) def forward(self, emb, **kwargs): - """See :func:`EncoderBase.forward()`""" + """See :func:`EncoderBase.forward()` + + Args: + emb (eole.modules.Embeddings): + embeddings to use, should have positional encodings + **kwargs + pad_mask: ``(batch, maxlen)`` False when value, True when pad + + Returns: + (torch.FloatTensor, torch.FloatTensor): + * enc_out ``(batch_size, src_len, model_dim)`` + * encoder 
final state: None in the case of Transformer + """ + pad_mask = kwargs.pop("pad_mask", None) assert pad_mask is not None, "TransformerEncoder requires a src pad mask" position_embeddings = kwargs.pop("position_embeddings", None) From 31f05c7f84991ba40bb957a121bd4f778eb90133 Mon Sep 17 00:00:00 2001 From: vince62s Date: Wed, 8 Jan 2025 10:35:51 +0100 Subject: [PATCH 03/11] formatting --- eole/decoders/transformer.py | 70 +++++++++++++++++++----------------- eole/encoders/transformer.py | 4 +-- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/eole/decoders/transformer.py b/eole/decoders/transformer.py index 2fe9fab1f..f5e59b07f 100644 --- a/eole/decoders/transformer.py +++ b/eole/decoders/transformer.py @@ -12,8 +12,7 @@ class TransformerDecoderLayer(nn.Module): - """The Transformer encoder from "Attention is All You Need" - :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` + """Single layer of a transformer decoder Args: model_config (eole.config.TransformerEncoderConfig): full encoder config @@ -43,7 +42,6 @@ def __init__( running_config=running_config, ) self.dropout = nn.Dropout(self.dropout_p) - if with_cross_attn: self.precontext_layernorm = LayerNorm[model_config.layer_norm]( model_config.hidden_size, eps=model_config.norm_eps @@ -54,16 +52,13 @@ def __init__( ) else: self.context_attn = None - if model_config.parallel_residual and not model_config.shared_layer_norm: self.residual_layernorm = LayerNorm[model_config.layer_norm]( model_config.hidden_size, eps=model_config.norm_eps ) - self.post_attention_layernorm = LayerNorm[model_config.layer_norm]( model_config.hidden_size, eps=model_config.norm_eps ) - if model_config.num_experts > 0: self.mlp = MoE(model_config, running_config) else: @@ -187,6 +182,10 @@ def __init__( with_cross_attn=False, ): super(TransformerDecoder, self).__init__() + self.alignment_layer = model_config.alignment_layer + self.with_cross_attn = with_cross_attn + self.sliding_window = model_config.sliding_window + self.transformer_layers = nn.ModuleList( [ TransformerDecoderLayer( @@ -197,14 +196,10 @@ def __init__( for i in range(model_config.layers) ] ) - # This is the Decoder out layer norm self.layer_norm = LayerNorm[model_config.layer_norm]( model_config.hidden_size, eps=model_config.norm_eps ) self._disable_cache() - self.alignment_layer = model_config.alignment_layer - self.with_cross_attn = with_cross_attn - self.sliding_window = model_config.sliding_window @classmethod def from_config(cls, model_config, running_config=None, with_cross_attn=False): @@ -220,8 +215,6 @@ def init_state(self, **kwargs): pass def map_state(self, fn): - # if self.state["src"] is not None: - # self.state["src"] = fn(self.state["src"], 0) for layer in self.transformer_layers: if self.with_cross_attn: if layer.context_attn.layer_cache[1]["keys"].numel() != 0: @@ -241,14 +234,11 @@ def map_state(self, fn): "key_pad_mask": z, } - def detach_state(self): - raise NotImplementedError - def update_dropout(self, dropout, attention_dropout): for layer in self.transformer_layers: layer.update_dropout(dropout, attention_dropout) - def _compute_attn_mask(self, tgt_pad_mask): + def _causal_attn_mask(self, tgt_pad_mask): tgt_len = tgt_pad_mask.size(-1) # Add triangular future_mask and pad_mask, result mask in (B, T, T). 
future_mask = torch.tril( @@ -260,14 +250,32 @@ def _compute_attn_mask(self, tgt_pad_mask): if self.sliding_window > 0: future_mask = future_mask.triu_(-self.sliding_window) attn_mask = ~tgt_pad_mask & future_mask.unsqueeze(0) + attn_mask = attn_mask.unsqueeze(1) + attn_mask = attn_mask.expand(-1, -1, attn_mask.size(3), -1) return attn_mask def forward(self, emb, **kwargs): - """Decode, possibly stepwise.""" + """Decode, possibly stepwise. + + Args: + emb (FloatTensor): ``(batch_size, tgt_len, model_dim)`` + **kwargs + enc_out: (only when using encoder/decoder) + src_pad_mask: (only when using encoder/decoder) + tgt_pad_mask (BoolTensor): ``(batch_size, tgt_len)`` + used to build the attention mask + step: int + with_aligh: bool + return_attention: bool, set by `attn_debug` + position_embeddings (FloatTensor): rotary position encodings, if any + Returns: + (FloatTensor, FloatTensor): + * emb: decoder_out ``(batch_size, tgt_len, model_dim)`` + * attns: standard and align if set + """ assert emb.dim() == 3 # batch x len x embedding_dim enc_out = kwargs.pop("enc_out", None) - tgt_pad_mask = kwargs.pop("tgt_pad_mask", None) assert tgt_pad_mask is not None, "TransformerDecoder requires a tgt pad mask" src_pad_mask = kwargs.pop("src_pad_mask", None) @@ -282,16 +290,14 @@ def forward(self, emb, **kwargs): attn_aligns = [] if emb.size(1) > 1: - # masking is necessary when sequence length is greater than one - attn_mask = self._compute_attn_mask(tgt_pad_mask) - attn_mask = attn_mask.unsqueeze(1) - attn_mask = attn_mask.expand(-1, -1, attn_mask.size(3), -1) + # causal masking is necessary when sequence length is greater than one + attn_mask = self._causal_attn_mask(tgt_pad_mask) if self.with_cross_attn: src_pad_mask = src_pad_mask.unsqueeze(1).expand( -1, -1, attn_mask.size(3), -1 ) # mask now are (batch x 1 x tlen x s or t len) - # 1 = heads to be expanded in MHA + # dim 1 to be broadcasted to heads in MHA else: attn_mask = None if self.with_cross_attn: @@ -331,23 +337,13 @@ def forward(self, emb, **kwargs): attns = {"std": top_attn} if with_align: attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)` - # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg - # TODO change the way attns is returned dict => list or tuple (onnx) return emb, attns def _enable_cache(self, device, pad_mask): for layer in self.transformer_layers: # first value set to True triggered by the beginning of decoding # layer_cache becomes active in the MultiHeadedAttention fwd - if layer.context_attn: - layer.context_attn.layer_cache = ( - True, - { - "keys": torch.tensor([], device=device), - "values": torch.tensor([], device=device), - }, - ) layer.self_attn.layer_cache = ( True, { @@ -356,6 +352,14 @@ def _enable_cache(self, device, pad_mask): "key_pad_mask": pad_mask, }, ) + if layer.context_attn: + layer.context_attn.layer_cache = ( + True, + { + "keys": torch.tensor([], device=device), + "values": torch.tensor([], device=device), + }, + ) def _disable_cache(self): for layer in self.transformer_layers: diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py index 85128a938..9e1a1ade6 100644 --- a/eole/encoders/transformer.py +++ b/eole/encoders/transformer.py @@ -93,7 +93,6 @@ def __init__( running_config=None, ): super(TransformerEncoder, self).__init__() - self.transformer_layers = nn.ModuleList( [ TransformerEncoderLayer( @@ -103,7 +102,6 @@ def __init__( for i in range(model_config.layers) ] ) - # This is the Encoder out layer norm self.layer_norm = 
LayerNorm[model_config.layer_norm]( model_config.hidden_size, eps=model_config.norm_eps ) @@ -112,7 +110,7 @@ def __init__( def from_config(cls, model_config, running_config=None): """Alternate constructor.""" return cls( - model_config, # TransformerEncoderConfig + model_config, running_config, ) From 9725332b788a65e55cf564476a32897e59636c0d Mon Sep 17 00:00:00 2001 From: vince62s Date: Wed, 8 Jan 2025 10:48:17 +0100 Subject: [PATCH 04/11] merge main --- eole/bin/convert/convert_T5.py | 42 +++++++------- eole/bin/convert/convert_falcon.py | 48 +++++++-------- eole/bin/convert/convert_llama.py | 84 ++++++++++++--------------- eole/bin/convert/convert_mpt.py | 12 ++-- eole/bin/convert/convert_redpajama.py | 6 +- eole/bin/convert/convert_xgen.py | 12 ++-- eole/config/models.py | 20 +------ eole/decoders/transformer.py | 20 ++----- eole/encoders/transformer.py | 8 +-- eole/models/model.py | 16 +---- eole/tests/test_transform.py | 8 +-- 11 files changed, 113 insertions(+), 163 deletions(-) diff --git a/eole/bin/convert/convert_T5.py b/eole/bin/convert/convert_T5.py index 36a65f638..8911c8f5c 100644 --- a/eole/bin/convert/convert_T5.py +++ b/eole/bin/convert/convert_T5.py @@ -129,12 +129,12 @@ def run(cls, args): eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight"] = checkpoint[ "decoder.block." + str(i) + ".layer.0.SelfAttention.k.weight" ].to(torch.float16) - eole_safetensor["encoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( - checkpoint["encoder.block." + str(i) + ".layer.0.SelfAttention.v.weight"].to(torch.float16) - ) - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( - checkpoint["decoder.block." + str(i) + ".layer.0.SelfAttention.v.weight"].to(torch.float16) - ) + eole_safetensor[ + "encoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" + ] = checkpoint["encoder.block." + str(i) + ".layer.0.SelfAttention.v.weight"].to(torch.float16) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" + ] = checkpoint["decoder.block." + str(i) + ".layer.0.SelfAttention.v.weight"].to(torch.float16) eole_safetensor["encoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"] = checkpoint[ "encoder.block." + str(i) + ".layer.0.SelfAttention.o.weight" @@ -171,15 +171,15 @@ def run(cls, args): eole_safetensor["decoder.transformer_layers." + str(i) + ".context_attn.linear_query.weight"] = ( checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.q.weight"] / (dimperhead**-0.5) ).to(torch.float16) - eole_safetensor["decoder.transformer_layers." + str(i) + ".context_attn.linear_keys.weight"] = ( - checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.k.weight"].to(torch.float16) - ) - eole_safetensor["decoder.transformer_layers." + str(i) + ".context_attn.linear_values.weight"] = ( - checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.v.weight"].to(torch.float16) - ) - eole_safetensor["decoder.transformer_layers." + str(i) + ".context_attn.final_linear.weight"] = ( - checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.o.weight"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".context_attn.linear_keys.weight" + ] = checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.k.weight"].to(torch.float16) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".context_attn.linear_values.weight" + ] = checkpoint["decoder.block." 
+ str(i) + ".layer.1.EncDecAttention.v.weight"].to(torch.float16) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".context_attn.final_linear.weight" + ] = checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.o.weight"].to(torch.float16) eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_2.weight"] = checkpoint[ "decoder.block." + str(i) + ".layer.1.layer_norm.weight" @@ -196,12 +196,12 @@ def run(cls, args): "decoder.block." + str(i) + ".layer.2.DenseReluDense.wi_1.weight" ].to(torch.float16) - eole_safetensor["encoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( - checkpoint["encoder.block." + str(i) + ".layer.1.layer_norm.weight"].to(torch.float16) - ) - eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( - checkpoint["decoder.block." + str(i) + ".layer.2.layer_norm.weight"].to(torch.float16) - ) + eole_safetensor[ + "encoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" + ] = checkpoint["encoder.block." + str(i) + ".layer.1.layer_norm.weight"].to(torch.float16) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" + ] = checkpoint["decoder.block." + str(i) + ".layer.2.layer_norm.weight"].to(torch.float16) if shard == 0: transformer_ff = eole_safetensor["decoder.transformer_layers.0.feed_forward.w_1.weight"].size(0) diff --git a/eole/bin/convert/convert_falcon.py b/eole/bin/convert/convert_falcon.py index 57d075a1c..6a75a2215 100644 --- a/eole/bin/convert/convert_falcon.py +++ b/eole/bin/convert/convert_falcon.py @@ -146,48 +146,48 @@ def run(cls, args): .reshape(hidden_size // heads * num_kv, hidden_size) # [512, 8192] ou [64, 4544] ) - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( - qkv_W[:, [-1]].reshape(hidden_size // heads * num_kv, hidden_size) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" + ] = qkv_W[:, [-1]].reshape(hidden_size // heads * num_kv, hidden_size) if "transformer.h." + str(i) + ".self_attention.dense.weight" in checkpoint.keys(): - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"] = ( - checkpoint["transformer.h." + str(i) + ".self_attention.dense.weight"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight" + ] = checkpoint["transformer.h." + str(i) + ".self_attention.dense.weight"].to(torch.float16) if shared_layer: if "transformer.h." + str(i) + ".input_layernorm.weight" in checkpoint.keys(): - eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_1.weight"] = ( - checkpoint["transformer.h." + str(i) + ".input_layernorm.weight"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".layer_norm_1.weight" + ] = checkpoint["transformer.h." + str(i) + ".input_layernorm.weight"].to(torch.float16) if "transformer.h." + str(i) + ".input_layernorm.bias" in checkpoint.keys(): eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_1.bias"] = checkpoint[ "transformer.h." + str(i) + ".input_layernorm.bias" ].to(torch.float16) else: if "transformer.h." + str(i) + ".ln_attn.weight" in checkpoint.keys(): - eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_1.weight"] = ( - checkpoint["transformer.h." + str(i) + ".ln_attn.weight"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." 
+ str(i) + ".layer_norm_1.weight" + ] = checkpoint["transformer.h." + str(i) + ".ln_attn.weight"].to(torch.float16) if "transformer.h." + str(i) + ".ln_attn.weight" in checkpoint.keys(): eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_1.bias"] = checkpoint[ "transformer.h." + str(i) + ".ln_attn.bias" ].to(torch.float16) if "transformer.h." + str(i) + ".ln_mlp.weight" in checkpoint.keys(): - eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_res.weight"] = ( - checkpoint["transformer.h." + str(i) + ".ln_mlp.weight"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".layer_norm_res.weight" + ] = checkpoint["transformer.h." + str(i) + ".ln_mlp.weight"].to(torch.float16) if "transformer.h." + str(i) + ".ln_mlp.bias" in checkpoint.keys(): - eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_res.bias"] = ( - checkpoint["transformer.h." + str(i) + ".ln_mlp.bias"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".layer_norm_res.bias" + ] = checkpoint["transformer.h." + str(i) + ".ln_mlp.bias"].to(torch.float16) if "transformer.h." + str(i) + ".mlp.dense_h_to_4h.weight" in checkpoint.keys(): - eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.w_1.weight"] = ( - checkpoint["transformer.h." + str(i) + ".mlp.dense_h_to_4h.weight"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".feed_forward.w_1.weight" + ] = checkpoint["transformer.h." + str(i) + ".mlp.dense_h_to_4h.weight"].to(torch.float16) if "transformer.h." + str(i) + ".mlp.dense_4h_to_h.weight" in checkpoint.keys(): - eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.w_2.weight"] = ( - checkpoint["transformer.h." + str(i) + ".mlp.dense_4h_to_h.weight"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".feed_forward.w_2.weight" + ] = checkpoint["transformer.h." + str(i) + ".mlp.dense_4h_to_h.weight"].to(torch.float16) if shard == 0: transformer_ff = eole_safetensor["decoder.transformer_layers.0.feed_forward.w_1.weight"].size(0) diff --git a/eole/bin/convert/convert_llama.py b/eole/bin/convert/convert_llama.py index 9d1e5d998..b0ecea87c 100644 --- a/eole/bin/convert/convert_llama.py +++ b/eole/bin/convert/convert_llama.py @@ -104,9 +104,9 @@ def run(cls, args): "layers." + str(i) + ".attention.wk.weight" ] - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( - checkpoint["layers." + str(i) + ".attention.wv.weight"] - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" + ] = checkpoint["layers." + str(i) + ".attention.wv.weight"] eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight"] = checkpoint[ "layers." + str(i) + ".attention.wq.weight" @@ -132,9 +132,9 @@ def run(cls, args): "layers." + str(i) + ".feed_forward.w3.weight" ] - eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( - checkpoint["layers." + str(i) + ".ffn_norm.weight"].clone() - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" + ] = checkpoint["layers." 
+ str(i) + ".ffn_norm.weight"].clone() input_shard = 1 while os.path.exists(os.path.join(args.model_dir, "consolidated.0" + str(input_shard) + ".pth")): @@ -169,52 +169,44 @@ def run(cls, args): min(-(decoder_layers // -args.nshards) * (shard + 1), decoder_layers), 1, ): - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight"] = ( - torch.cat( - ( - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight" - ], - checkpoint["layers." + str(i) + ".attention.wk.weight"], - ), - dim=0, - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight" + ] = torch.cat( + ( + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight"], + checkpoint["layers." + str(i) + ".attention.wk.weight"], + ), + dim=0, ) - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( - torch.cat( - ( - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" - ], - checkpoint["layers." + str(i) + ".attention.wv.weight"], - ), - dim=0, - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" + ] = torch.cat( + ( + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"], + checkpoint["layers." + str(i) + ".attention.wv.weight"], + ), + dim=0, ) - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight"] = ( - torch.cat( - ( - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight" - ], - checkpoint["layers." + str(i) + ".attention.wq.weight"], - ), - dim=0, - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight" + ] = torch.cat( + ( + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight"], + checkpoint["layers." + str(i) + ".attention.wq.weight"], + ), + dim=0, ) - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"] = ( - torch.cat( - ( - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight" - ], - checkpoint["layers." + str(i) + ".attention.wo.weight"], - ), - dim=1, - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight" + ] = torch.cat( + ( + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"], + checkpoint["layers." + str(i) + ".attention.wo.weight"], + ), + dim=1, ) eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.w_1.weight"] = torch.cat( diff --git a/eole/bin/convert/convert_mpt.py b/eole/bin/convert/convert_mpt.py index f23cbe80e..e0d8613e3 100755 --- a/eole/bin/convert/convert_mpt.py +++ b/eole/bin/convert/convert_mpt.py @@ -89,9 +89,9 @@ def run(cls, args): eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight"] = checkpoint[ "transformer.blocks." + str(i) + ".attn.Wqkv.weight" ][hidden_size : (hidden_size * 2), :].clone() - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( - checkpoint["transformer.blocks." + str(i) + ".attn.Wqkv.weight"][(hidden_size * 2) :, :].clone() - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" + ] = checkpoint["transformer.blocks." + str(i) + ".attn.Wqkv.weight"][(hidden_size * 2) :, :].clone() eole_safetensor["decoder.transformer_layers." 
+ str(i) + ".self_attn.final_linear.weight"] = checkpoint[ "transformer.blocks." + str(i) + ".attn.out_proj.weight" @@ -113,9 +113,9 @@ def run(cls, args): "transformer.blocks." + str(i) + ".ffn.down_proj.weight" ] - eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( - checkpoint["transformer.blocks." + str(i) + ".norm_2.weight"] - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" + ] = checkpoint["transformer.blocks." + str(i) + ".norm_2.weight"] eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.bias"] = torch.zeros( eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"].size(0), dtype=torch.float16, diff --git a/eole/bin/convert/convert_redpajama.py b/eole/bin/convert/convert_redpajama.py index b80ecda87..1be73f792 100755 --- a/eole/bin/convert/convert_redpajama.py +++ b/eole/bin/convert/convert_redpajama.py @@ -156,9 +156,9 @@ def run(cls, args): "gpt_neox.layers." + str(i) + ".mlp.dense_4h_to_h.bias" ] - eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( - checkpoint["gpt_neox.layers." + str(i) + ".post_attention_layernorm.weight"] - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" + ] = checkpoint["gpt_neox.layers." + str(i) + ".post_attention_layernorm.weight"] eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.bias"] = checkpoint[ "gpt_neox.layers." + str(i) + ".post_attention_layernorm.bias" ] diff --git a/eole/bin/convert/convert_xgen.py b/eole/bin/convert/convert_xgen.py index 1a5debeae..656b6e4bc 100755 --- a/eole/bin/convert/convert_xgen.py +++ b/eole/bin/convert/convert_xgen.py @@ -100,9 +100,9 @@ def run(cls, args): .transpose(1, 2) .reshape(hidden_size, hidden_size) ).to(torch.float16) - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( - checkpoint["model.layers." + str(i) + ".self_attn.v_proj.weight"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" + ] = checkpoint["model.layers." + str(i) + ".self_attn.v_proj.weight"].to(torch.float16) eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"] = checkpoint[ "model.layers." + str(i) + ".self_attn.o_proj.weight" @@ -123,9 +123,9 @@ def run(cls, args): "model.layers." + str(i) + ".mlp.up_proj.weight" ].to(torch.float16) - eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( - checkpoint["model.layers." + str(i) + ".post_attention_layernorm.weight"].to(torch.float16) - ) + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" + ] = checkpoint["model.layers." 
+ str(i) + ".post_attention_layernorm.weight"].to(torch.float16) if shard == 0: transformer_ff = eole_safetensor["decoder.transformer_layers.0.feed_forward.w_1.weight"].size(0) diff --git a/eole/config/models.py b/eole/config/models.py index dc6f92f56..56441495e 100644 --- a/eole/config/models.py +++ b/eole/config/models.py @@ -328,24 +328,12 @@ class BaseModelConfig(Config): default_factory=EmbeddingsConfig, description="Contains most of the args useful to build the Embeddings module.", ) - encoder: ( - Union[ - TransformerEncoderConfig, - RnnEncoderConfig, - CnnEncoderConfig, - MeanEncoderConfig, - ] - | None - ) = Field( + encoder: (Union[TransformerEncoderConfig, RnnEncoderConfig, CnnEncoderConfig, MeanEncoderConfig,] | None) = Field( default=None, discriminator="encoder_type", description="Major parameters of an encoder.", ) # we shall use discriminators here - decoder: Union[ - TransformerDecoderConfig, - RnnDecoderConfig, - CnnDecoderConfig, - ] | None = Field( + decoder: Union[TransformerDecoderConfig, RnnDecoderConfig, CnnDecoderConfig,] | None = Field( default=None, discriminator="decoder_type", description="Major parameters of a decoder.", @@ -475,9 +463,7 @@ def update_model_opts(self): self.encoder.update(**update_dict.pop("encoder")) if self.decoder is not None: - update_dict["decoder"] = { - "tgt_word_vec_size": self.embeddings.tgt_word_vec_size - } + update_dict["decoder"] = {"tgt_word_vec_size": self.embeddings.tgt_word_vec_size} if getattr(self.decoder, "decoder_type", None) == "transformer": update_dict["decoder"].update( { diff --git a/eole/decoders/transformer.py b/eole/decoders/transformer.py index f5e59b07f..aea056b09 100644 --- a/eole/decoders/transformer.py +++ b/eole/decoders/transformer.py @@ -34,9 +34,7 @@ def __init__( self.alignment_heads = model_config.alignment_heads # order of layers corresponds to forward flow of tensors - self.input_layernorm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) + self.input_layernorm = LayerNorm[model_config.layer_norm](model_config.hidden_size, eps=model_config.norm_eps) self.self_attn = SelfMHA( model_config, running_config=running_config, @@ -196,9 +194,7 @@ def __init__( for i in range(model_config.layers) ] ) - self.layer_norm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) + self.layer_norm = LayerNorm[model_config.layer_norm](model_config.hidden_size, eps=model_config.norm_eps) self._disable_cache() @classmethod @@ -242,9 +238,7 @@ def _causal_attn_mask(self, tgt_pad_mask): tgt_len = tgt_pad_mask.size(-1) # Add triangular future_mask and pad_mask, result mask in (B, T, T). 
future_mask = torch.tril( - torch.ones( - (tgt_len, tgt_len), device=tgt_pad_mask.device, dtype=torch.bool - ), + torch.ones((tgt_len, tgt_len), device=tgt_pad_mask.device, dtype=torch.bool), diagonal=0, ) if self.sliding_window > 0: @@ -280,9 +274,7 @@ def forward(self, emb, **kwargs): assert tgt_pad_mask is not None, "TransformerDecoder requires a tgt pad mask" src_pad_mask = kwargs.pop("src_pad_mask", None) if self.with_cross_attn: - assert ( - src_pad_mask is not None - ), "TransformerDecoder requires a src pad mask" + assert src_pad_mask is not None, "TransformerDecoder requires a src pad mask" step = kwargs.pop("step", None) with_align = kwargs.pop("with_align", False) return_attn = with_align or kwargs.pop("return_attn", False) @@ -293,9 +285,7 @@ def forward(self, emb, **kwargs): # causal masking is necessary when sequence length is greater than one attn_mask = self._causal_attn_mask(tgt_pad_mask) if self.with_cross_attn: - src_pad_mask = src_pad_mask.unsqueeze(1).expand( - -1, -1, attn_mask.size(3), -1 - ) + src_pad_mask = src_pad_mask.unsqueeze(1).expand(-1, -1, attn_mask.size(3), -1) # mask now are (batch x 1 x tlen x s or t len) # dim 1 to be broadcasted to heads in MHA else: diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py index 591ddfec4..78cbda900 100644 --- a/eole/encoders/transformer.py +++ b/eole/encoders/transformer.py @@ -29,9 +29,7 @@ def __init__( self.dropout_p = getattr(running_config, "dropout", [0.0])[0] # order of layers corresponds to forward flow of tensors - self.input_layernorm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) + self.input_layernorm = LayerNorm[model_config.layer_norm](model_config.hidden_size, eps=model_config.norm_eps) self.self_attn = SelfMHA( model_config, running_config=running_config, @@ -100,9 +98,7 @@ def __init__( for i in range(model_config.layers) ] ) - self.layer_norm = LayerNorm[model_config.layer_norm]( - model_config.hidden_size, eps=model_config.norm_eps - ) + self.layer_norm = LayerNorm[model_config.layer_norm](model_config.hidden_size, eps=model_config.norm_eps) @classmethod def from_config(cls, model_config, running_config=None): diff --git a/eole/models/model.py b/eole/models/model.py index b8571d5ac..e46efa4e7 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -504,13 +504,7 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset) else: row_slice_start = 0 row_slice_end = param.data.size(1) - assert ( - param.data.size() - == ckpt_t[ - col_slice_start:col_slice_end, - row_slice_start:row_slice_end, - ].size() - ), ( + assert param.data.size() == ckpt_t[col_slice_start:col_slice_end, row_slice_start:row_slice_end,].size(), ( "An error in model's partition and checkpoint's slice was detected, " f"[{name}, {module}, {param_name}, {param.data.size()}, {ckpt_t.size()}]" ) @@ -715,9 +709,7 @@ def build_blocks(cls, model_config, vocabs, running_config=None): share_embeddings=model_config.share_embeddings, src_emb=src_emb, ) - decoder = build_decoder( - model_config, running_config=running_config, with_cross_attn=True - ) + decoder = build_decoder(model_config, running_config=running_config, with_cross_attn=True) return cls( encoder=encoder, decoder=decoder, @@ -797,9 +789,7 @@ def __init__(self, **kwargs): @classmethod def build_blocks(cls, model_config, vocabs, running_config=None): tgt_emb = build_tgt_emb(model_config, vocabs, running_config=running_config) - decoder = build_decoder( - model_config, 
running_config=running_config, with_cross_attn=False - ) + decoder = build_decoder(model_config, running_config=running_config, with_cross_attn=False) return cls( decoder=decoder, tgt_emb=tgt_emb, diff --git a/eole/tests/test_transform.py b/eole/tests/test_transform.py index a5dd97f52..b602fb75f 100644 --- a/eole/tests/test_transform.py +++ b/eole/tests/test_transform.py @@ -370,9 +370,7 @@ def test_pyonmttok_bpe(self): # Test mask location ex = { - "src": ( - "### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆" "### Response : ⦅newline⦆response" - ), + "src": ("### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆" "### Response : ⦅newline⦆response"), "tgt": "", } ex["src"] = ex["src"].split(" ") @@ -445,9 +443,7 @@ def test_pyonmttok_sp(self): # Test mask location ex = { - "src": ( - "### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆" "### Response : ⦅newline⦆response" - ), + "src": ("### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆" "### Response : ⦅newline⦆response"), "tgt": "", } ex["src"] = ex["src"].split(" ") From f2211bcd6984871d7a5da8ec028df42115b0c239 Mon Sep 17 00:00:00 2001 From: vince62s Date: Wed, 8 Jan 2025 10:59:11 +0100 Subject: [PATCH 05/11] black 24 --- eole/bin/convert/convert_T5.py | 42 +++++++------- eole/bin/convert/convert_falcon.py | 48 +++++++-------- eole/bin/convert/convert_llama.py | 84 +++++++++++++++------------ eole/bin/convert/convert_mpt.py | 12 ++-- eole/bin/convert/convert_redpajama.py | 6 +- eole/bin/convert/convert_xgen.py | 12 ++-- eole/config/models.py | 19 +++++- eole/models/model.py | 8 ++- eole/tests/test_transform.py | 8 ++- 9 files changed, 136 insertions(+), 103 deletions(-) diff --git a/eole/bin/convert/convert_T5.py b/eole/bin/convert/convert_T5.py index 8911c8f5c..36a65f638 100644 --- a/eole/bin/convert/convert_T5.py +++ b/eole/bin/convert/convert_T5.py @@ -129,12 +129,12 @@ def run(cls, args): eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight"] = checkpoint[ "decoder.block." + str(i) + ".layer.0.SelfAttention.k.weight" ].to(torch.float16) - eole_safetensor[ - "encoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" - ] = checkpoint["encoder.block." + str(i) + ".layer.0.SelfAttention.v.weight"].to(torch.float16) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" - ] = checkpoint["decoder.block." + str(i) + ".layer.0.SelfAttention.v.weight"].to(torch.float16) + eole_safetensor["encoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( + checkpoint["encoder.block." + str(i) + ".layer.0.SelfAttention.v.weight"].to(torch.float16) + ) + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( + checkpoint["decoder.block." + str(i) + ".layer.0.SelfAttention.v.weight"].to(torch.float16) + ) eole_safetensor["encoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"] = checkpoint[ "encoder.block." + str(i) + ".layer.0.SelfAttention.o.weight" @@ -171,15 +171,15 @@ def run(cls, args): eole_safetensor["decoder.transformer_layers." + str(i) + ".context_attn.linear_query.weight"] = ( checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.q.weight"] / (dimperhead**-0.5) ).to(torch.float16) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".context_attn.linear_keys.weight" - ] = checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.k.weight"].to(torch.float16) - eole_safetensor[ - "decoder.transformer_layers." 
+ str(i) + ".context_attn.linear_values.weight" - ] = checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.v.weight"].to(torch.float16) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".context_attn.final_linear.weight" - ] = checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.o.weight"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." + str(i) + ".context_attn.linear_keys.weight"] = ( + checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.k.weight"].to(torch.float16) + ) + eole_safetensor["decoder.transformer_layers." + str(i) + ".context_attn.linear_values.weight"] = ( + checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.v.weight"].to(torch.float16) + ) + eole_safetensor["decoder.transformer_layers." + str(i) + ".context_attn.final_linear.weight"] = ( + checkpoint["decoder.block." + str(i) + ".layer.1.EncDecAttention.o.weight"].to(torch.float16) + ) eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_2.weight"] = checkpoint[ "decoder.block." + str(i) + ".layer.1.layer_norm.weight" @@ -196,12 +196,12 @@ def run(cls, args): "decoder.block." + str(i) + ".layer.2.DenseReluDense.wi_1.weight" ].to(torch.float16) - eole_safetensor[ - "encoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" - ] = checkpoint["encoder.block." + str(i) + ".layer.1.layer_norm.weight"].to(torch.float16) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" - ] = checkpoint["decoder.block." + str(i) + ".layer.2.layer_norm.weight"].to(torch.float16) + eole_safetensor["encoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( + checkpoint["encoder.block." + str(i) + ".layer.1.layer_norm.weight"].to(torch.float16) + ) + eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( + checkpoint["decoder.block." + str(i) + ".layer.2.layer_norm.weight"].to(torch.float16) + ) if shard == 0: transformer_ff = eole_safetensor["decoder.transformer_layers.0.feed_forward.w_1.weight"].size(0) diff --git a/eole/bin/convert/convert_falcon.py b/eole/bin/convert/convert_falcon.py index 6a75a2215..57d075a1c 100644 --- a/eole/bin/convert/convert_falcon.py +++ b/eole/bin/convert/convert_falcon.py @@ -146,48 +146,48 @@ def run(cls, args): .reshape(hidden_size // heads * num_kv, hidden_size) # [512, 8192] ou [64, 4544] ) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" - ] = qkv_W[:, [-1]].reshape(hidden_size // heads * num_kv, hidden_size) + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( + qkv_W[:, [-1]].reshape(hidden_size // heads * num_kv, hidden_size) + ) if "transformer.h." + str(i) + ".self_attention.dense.weight" in checkpoint.keys(): - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight" - ] = checkpoint["transformer.h." + str(i) + ".self_attention.dense.weight"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"] = ( + checkpoint["transformer.h." + str(i) + ".self_attention.dense.weight"].to(torch.float16) + ) if shared_layer: if "transformer.h." + str(i) + ".input_layernorm.weight" in checkpoint.keys(): - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".layer_norm_1.weight" - ] = checkpoint["transformer.h." + str(i) + ".input_layernorm.weight"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." 
+ str(i) + ".layer_norm_1.weight"] = ( + checkpoint["transformer.h." + str(i) + ".input_layernorm.weight"].to(torch.float16) + ) if "transformer.h." + str(i) + ".input_layernorm.bias" in checkpoint.keys(): eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_1.bias"] = checkpoint[ "transformer.h." + str(i) + ".input_layernorm.bias" ].to(torch.float16) else: if "transformer.h." + str(i) + ".ln_attn.weight" in checkpoint.keys(): - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".layer_norm_1.weight" - ] = checkpoint["transformer.h." + str(i) + ".ln_attn.weight"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_1.weight"] = ( + checkpoint["transformer.h." + str(i) + ".ln_attn.weight"].to(torch.float16) + ) if "transformer.h." + str(i) + ".ln_attn.weight" in checkpoint.keys(): eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_1.bias"] = checkpoint[ "transformer.h." + str(i) + ".ln_attn.bias" ].to(torch.float16) if "transformer.h." + str(i) + ".ln_mlp.weight" in checkpoint.keys(): - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".layer_norm_res.weight" - ] = checkpoint["transformer.h." + str(i) + ".ln_mlp.weight"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_res.weight"] = ( + checkpoint["transformer.h." + str(i) + ".ln_mlp.weight"].to(torch.float16) + ) if "transformer.h." + str(i) + ".ln_mlp.bias" in checkpoint.keys(): - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".layer_norm_res.bias" - ] = checkpoint["transformer.h." + str(i) + ".ln_mlp.bias"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." + str(i) + ".layer_norm_res.bias"] = ( + checkpoint["transformer.h." + str(i) + ".ln_mlp.bias"].to(torch.float16) + ) if "transformer.h." + str(i) + ".mlp.dense_h_to_4h.weight" in checkpoint.keys(): - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".feed_forward.w_1.weight" - ] = checkpoint["transformer.h." + str(i) + ".mlp.dense_h_to_4h.weight"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.w_1.weight"] = ( + checkpoint["transformer.h." + str(i) + ".mlp.dense_h_to_4h.weight"].to(torch.float16) + ) if "transformer.h." + str(i) + ".mlp.dense_4h_to_h.weight" in checkpoint.keys(): - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".feed_forward.w_2.weight" - ] = checkpoint["transformer.h." + str(i) + ".mlp.dense_4h_to_h.weight"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.w_2.weight"] = ( + checkpoint["transformer.h." + str(i) + ".mlp.dense_4h_to_h.weight"].to(torch.float16) + ) if shard == 0: transformer_ff = eole_safetensor["decoder.transformer_layers.0.feed_forward.w_1.weight"].size(0) diff --git a/eole/bin/convert/convert_llama.py b/eole/bin/convert/convert_llama.py index b0ecea87c..9d1e5d998 100644 --- a/eole/bin/convert/convert_llama.py +++ b/eole/bin/convert/convert_llama.py @@ -104,9 +104,9 @@ def run(cls, args): "layers." + str(i) + ".attention.wk.weight" ] - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" - ] = checkpoint["layers." + str(i) + ".attention.wv.weight"] + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( + checkpoint["layers." + str(i) + ".attention.wv.weight"] + ) eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight"] = checkpoint[ "layers." 
+ str(i) + ".attention.wq.weight" @@ -132,9 +132,9 @@ def run(cls, args): "layers." + str(i) + ".feed_forward.w3.weight" ] - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" - ] = checkpoint["layers." + str(i) + ".ffn_norm.weight"].clone() + eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( + checkpoint["layers." + str(i) + ".ffn_norm.weight"].clone() + ) input_shard = 1 while os.path.exists(os.path.join(args.model_dir, "consolidated.0" + str(input_shard) + ".pth")): @@ -169,44 +169,52 @@ def run(cls, args): min(-(decoder_layers // -args.nshards) * (shard + 1), decoder_layers), 1, ): - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight" - ] = torch.cat( - ( - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight"], - checkpoint["layers." + str(i) + ".attention.wk.weight"], - ), - dim=0, + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight"] = ( + torch.cat( + ( + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight" + ], + checkpoint["layers." + str(i) + ".attention.wk.weight"], + ), + dim=0, + ) ) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" - ] = torch.cat( - ( - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"], - checkpoint["layers." + str(i) + ".attention.wv.weight"], - ), - dim=0, + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( + torch.cat( + ( + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" + ], + checkpoint["layers." + str(i) + ".attention.wv.weight"], + ), + dim=0, + ) ) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight" - ] = torch.cat( - ( - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight"], - checkpoint["layers." + str(i) + ".attention.wq.weight"], - ), - dim=0, + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight"] = ( + torch.cat( + ( + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.linear_query.weight" + ], + checkpoint["layers." + str(i) + ".attention.wq.weight"], + ), + dim=0, + ) ) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight" - ] = torch.cat( - ( - eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"], - checkpoint["layers." + str(i) + ".attention.wo.weight"], - ), - dim=1, + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"] = ( + torch.cat( + ( + eole_safetensor[ + "decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight" + ], + checkpoint["layers." + str(i) + ".attention.wo.weight"], + ), + dim=1, + ) ) eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.w_1.weight"] = torch.cat( diff --git a/eole/bin/convert/convert_mpt.py b/eole/bin/convert/convert_mpt.py index e0d8613e3..f23cbe80e 100755 --- a/eole/bin/convert/convert_mpt.py +++ b/eole/bin/convert/convert_mpt.py @@ -89,9 +89,9 @@ def run(cls, args): eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_keys.weight"] = checkpoint[ "transformer.blocks." 
+ str(i) + ".attn.Wqkv.weight" ][hidden_size : (hidden_size * 2), :].clone() - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" - ] = checkpoint["transformer.blocks." + str(i) + ".attn.Wqkv.weight"][(hidden_size * 2) :, :].clone() + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( + checkpoint["transformer.blocks." + str(i) + ".attn.Wqkv.weight"][(hidden_size * 2) :, :].clone() + ) eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"] = checkpoint[ "transformer.blocks." + str(i) + ".attn.out_proj.weight" @@ -113,9 +113,9 @@ def run(cls, args): "transformer.blocks." + str(i) + ".ffn.down_proj.weight" ] - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" - ] = checkpoint["transformer.blocks." + str(i) + ".norm_2.weight"] + eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( + checkpoint["transformer.blocks." + str(i) + ".norm_2.weight"] + ) eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.bias"] = torch.zeros( eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"].size(0), dtype=torch.float16, diff --git a/eole/bin/convert/convert_redpajama.py b/eole/bin/convert/convert_redpajama.py index 1be73f792..b80ecda87 100755 --- a/eole/bin/convert/convert_redpajama.py +++ b/eole/bin/convert/convert_redpajama.py @@ -156,9 +156,9 @@ def run(cls, args): "gpt_neox.layers." + str(i) + ".mlp.dense_4h_to_h.bias" ] - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" - ] = checkpoint["gpt_neox.layers." + str(i) + ".post_attention_layernorm.weight"] + eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( + checkpoint["gpt_neox.layers." + str(i) + ".post_attention_layernorm.weight"] + ) eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.bias"] = checkpoint[ "gpt_neox.layers." + str(i) + ".post_attention_layernorm.bias" ] diff --git a/eole/bin/convert/convert_xgen.py b/eole/bin/convert/convert_xgen.py index 656b6e4bc..1a5debeae 100755 --- a/eole/bin/convert/convert_xgen.py +++ b/eole/bin/convert/convert_xgen.py @@ -100,9 +100,9 @@ def run(cls, args): .transpose(1, 2) .reshape(hidden_size, hidden_size) ).to(torch.float16) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight" - ] = checkpoint["model.layers." + str(i) + ".self_attn.v_proj.weight"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.linear_values.weight"] = ( + checkpoint["model.layers." + str(i) + ".self_attn.v_proj.weight"].to(torch.float16) + ) eole_safetensor["decoder.transformer_layers." + str(i) + ".self_attn.final_linear.weight"] = checkpoint[ "model.layers." + str(i) + ".self_attn.o_proj.weight" @@ -123,9 +123,9 @@ def run(cls, args): "model.layers." + str(i) + ".mlp.up_proj.weight" ].to(torch.float16) - eole_safetensor[ - "decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight" - ] = checkpoint["model.layers." + str(i) + ".post_attention_layernorm.weight"].to(torch.float16) + eole_safetensor["decoder.transformer_layers." + str(i) + ".feed_forward.layer_norm.weight"] = ( + checkpoint["model.layers." 
+ str(i) + ".post_attention_layernorm.weight"].to(torch.float16) + ) if shard == 0: transformer_ff = eole_safetensor["decoder.transformer_layers.0.feed_forward.w_1.weight"].size(0) diff --git a/eole/config/models.py b/eole/config/models.py index 56441495e..efbac3c1d 100644 --- a/eole/config/models.py +++ b/eole/config/models.py @@ -328,12 +328,27 @@ class BaseModelConfig(Config): default_factory=EmbeddingsConfig, description="Contains most of the args useful to build the Embeddings module.", ) - encoder: (Union[TransformerEncoderConfig, RnnEncoderConfig, CnnEncoderConfig, MeanEncoderConfig,] | None) = Field( + encoder: ( + Union[ + TransformerEncoderConfig, + RnnEncoderConfig, + CnnEncoderConfig, + MeanEncoderConfig, + ] + | None + ) = Field( default=None, discriminator="encoder_type", description="Major parameters of an encoder.", ) # we shall use discriminators here - decoder: Union[TransformerDecoderConfig, RnnDecoderConfig, CnnDecoderConfig,] | None = Field( + decoder: ( + Union[ + TransformerDecoderConfig, + RnnDecoderConfig, + CnnDecoderConfig, + ] + | None + ) = Field( default=None, discriminator="decoder_type", description="Major parameters of a decoder.", diff --git a/eole/models/model.py b/eole/models/model.py index e46efa4e7..0b24b3896 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -504,7 +504,13 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset) else: row_slice_start = 0 row_slice_end = param.data.size(1) - assert param.data.size() == ckpt_t[col_slice_start:col_slice_end, row_slice_start:row_slice_end,].size(), ( + assert ( + param.data.size() + == ckpt_t[ + col_slice_start:col_slice_end, + row_slice_start:row_slice_end, + ].size() + ), ( "An error in model's partition and checkpoint's slice was detected, " f"[{name}, {module}, {param_name}, {param.data.size()}, {ckpt_t.size()}]" ) diff --git a/eole/tests/test_transform.py b/eole/tests/test_transform.py index b602fb75f..a5dd97f52 100644 --- a/eole/tests/test_transform.py +++ b/eole/tests/test_transform.py @@ -370,7 +370,9 @@ def test_pyonmttok_bpe(self): # Test mask location ex = { - "src": ("### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆" "### Response : ⦅newline⦆response"), + "src": ( + "### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆" "### Response : ⦅newline⦆response" + ), "tgt": "", } ex["src"] = ex["src"].split(" ") @@ -443,7 +445,9 @@ def test_pyonmttok_sp(self): # Test mask location ex = { - "src": ("### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆" "### Response : ⦅newline⦆response"), + "src": ( + "### Instruction: ⦅newline⦆instruction⦅newline⦆⦅newline⦆" "### Response : ⦅newline⦆response" + ), "tgt": "", } ex["src"] = ex["src"].split(" ") From 5103ac466332f82010aaa33b91289c3e8ea09f87 Mon Sep 17 00:00:00 2001 From: vince62s Date: Wed, 8 Jan 2025 12:32:22 +0100 Subject: [PATCH 06/11] unused args for now --- eole/models/model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/eole/models/model.py b/eole/models/model.py index 0b24b3896..de06bb8d4 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -749,7 +749,6 @@ def forward(self, src, tgt, src_len, with_align=False): dec_out, attns = self.decoder( self.tgt_emb(dec_in), enc_out=enc_out, - src_len=src_len, with_align=with_align, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask, @@ -813,9 +812,6 @@ def forward(self, src, _, src_len, with_align=False): position_embeddings = self.rope.update(src.size(1), step=0) dec_out, attns = self.decoder( self.tgt_emb(src), - enc_out=None, - 
src_len=src_len, - with_align=with_align, tgt_pad_mask=src.eq(self.pad_idx).unsqueeze(1), position_embeddings=position_embeddings, ) From b8e77090015020e846f72792952b2d9bc5f42e3f Mon Sep 17 00:00:00 2001 From: vince62s Date: Thu, 9 Jan 2025 10:38:43 +0100 Subject: [PATCH 07/11] split position embed code from MHA --- eole/modules/multi_headed_attn.py | 233 +++---------------------- eole/modules/relative_position_bias.py | 118 +++++++++++++ eole/modules/rope.py | 73 ++++++-- 3 files changed, 200 insertions(+), 224 deletions(-) create mode 100644 eole/modules/relative_position_bias.py diff --git a/eole/modules/multi_headed_attn.py b/eole/modules/multi_headed_attn.py index 17655bf35..f69ac2320 100644 --- a/eole/modules/multi_headed_attn.py +++ b/eole/modules/multi_headed_attn.py @@ -2,194 +2,23 @@ import torch import torch.nn as nn -from math import log, sqrt +import torch.nn.functional as F +from math import sqrt from torch import Tensor from typing import Optional, Tuple -from torch.nn.functional import scaled_dot_product_attention from torch.utils.checkpoint import checkpoint from torch.nn.utils import skip_init -from .alibi_position_bias import AlibiPositionalBias from torch.distributed import all_reduce from importlib import import_module from eole.constants import PositionEncodingType - -# Help functions for Rotary Embeddings -# https://arxiv.org/pdf/2104.09864.pdf - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_emb(query, key, rope, interleave): - # now rope is a tuple (cos, sin) - cos, sin = rope - if interleave: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - query_ = query.float().reshape(*query.shape[:-1], -1, 2) - key_ = key.float().reshape(*key.shape[:-1], -1, 2) - - # Reshape cos and sin to match the dimensions of query_ and key_ - cos = cos[:, : cos.size(1) // 2].view(1, query_.size(1), 1, query_.size(3)) - sin = sin[:, : sin.size(1) // 2].view(1, key_.size(1), 1, key_.size(3)) - - query_rotated = query_[..., 0] * cos - query_[..., 1] * sin - query_rotated_imag = query_[..., 0] * sin + query_[..., 1] * cos - query_out = torch.stack((query_rotated, query_rotated_imag), dim=-1).flatten(3) - - key_rotated = key_[..., 0] * cos - key_[..., 1] * sin - key_rotated_imag = key_[..., 0] * sin + key_[..., 1] * cos - key_out = torch.stack((key_rotated, key_rotated_imag), dim=-1).flatten(3) - - return query_out.transpose(1, 2).type_as(query), key_out.transpose(1, 2).type_as(key) - - # Old code with complex instead - # rope_complex = torch.complex(cos, sin) - # query_ = torch.view_as_complex(query_) - # key_ = torch.view_as_complex(key_) - # query_out = torch.view_as_real(query_ * rope_complex).flatten(3) - # key_out = torch.view_as_real(key_ * rope_complex).flatten(3) - # return query_out.transpose(1, 2).type_as(query), key_out.transpose( - # 1, 2 - # ).type_as(key) - else: - rotary_dim = cos.size(1) - head_dim = query.size(3) - if rotary_dim < head_dim: - q_embed = (query[:, :, :, :rotary_dim] * cos) + (rotate_half(query[:, :, :, :rotary_dim]) * sin) - k_embed = (key[:, :, :, :rotary_dim] * cos) + (rotate_half(key[:, :, :, :rotary_dim]) * sin) - q_embed = torch.cat([q_embed, query[:, :, :, rotary_dim:]], dim=-1) - k_embed = torch.cat([k_embed, key[:, :, :, rotary_dim:]], dim=-1) - else: - q_embed = (query * cos) + (rotate_half(query) * sin) - k_embed = (key * cos) + (rotate_half(key) * sin) - return q_embed.type_as(query), 
k_embed.type_as(key) - - -# Help functions for "Relative" position encoding -# https://arxiv.org/abs/1803.02155 - - -def relative_matmul(x: Tensor, z: Tensor, transpose: bool) -> Tensor: - """ - Helper function for relative positions attention. - https://arxiv.org/pdf/1803.02155.pdf - x shape [batch_size x heads x q_len x k_len] - """ - batch_size, heads, length, _ = x.size() - x_t = x.permute(2, 0, 1, 3) - x_t_r = x_t.contiguous().view(length, heads * batch_size, -1) - if transpose: - z = z.transpose(1, 2) - x_tz_matmul = torch.matmul(x_t_r, z) - x_tz_matmul_r = x_tz_matmul.view(length, batch_size, heads, -1) - x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3) - return x_tz_matmul_r_t - - -def gen_relative_positions( - length: int, - n_positions: int, - cache: bool = False, - device: Optional[torch.device] = None, -) -> Tensor: - """Generate the clipped relative positions matrix - for a given length and maximum relative positions""" - if cache: - distance_mat = torch.arange(-length + 1, 1, 1, device=device).unsqueeze(0) - else: - range_vec = torch.arange(length, device=device) - range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1) - distance_mat = range_mat - range_mat.transpose(0, 1) - distance_mat_clipped = torch.clamp(distance_mat, min=-n_positions, max=n_positions) - # Shift values to be >= 0 - final_mat = distance_mat_clipped + n_positions - return final_mat - - -def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/ - mesh_tensorflow/transformer/transformer_layers.py#L593 - Translate relative position to a bucket number for relative attention. - The relative position is defined as memory_position - query_position, - i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. - We use smaller buckets for small absolute relative_position and larger buckets for - larger absolute relative_positions. All relative positions >=max_distance map to the - same bucket. All relative positions <=-max_distance map to the same bucket. 
- This should allow for more graceful generalization to longer sequences than the - model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values - in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) - # now relative_position is in the range [0, inf) - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions - # up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) / log(max_distance / max_exact) * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, - torch.full_like(relative_position_if_large, num_buckets - 1), - ) - - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) - return relative_buckets - - -def compute_bias( - query_length, - key_length, - is_decoder, - n_positions, - relative_positions_buckets, - device=None, -): - """Compute binned relative position bias""" - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] - relative_position = memory_position - context_position # shape (query_length, key_length) - relative_position_bucket = _relative_position_bucket( - relative_position, # shape (query_length, key_length) - bidirectional=(not is_decoder), - num_buckets=relative_positions_buckets, - max_distance=n_positions, - ) - return relative_position_bucket +from .relative_position_bias import relative_matmul, gen_relative_positions, compute_bias +from .alibi_position_bias import AlibiPositionalBias +from .rope import apply_rotary_emb # Help functions to split model dim per head - - def shape(x: Tensor, dim_per_head: int) -> Tensor: - """ - Projection. - [batchsize x length x modeldim] + """[batchsize x length x modeldim] -> [batchsize x heads x length x dimperhead] """ x_0, x_1, _ = x.size() @@ -197,9 +26,7 @@ def shape(x: Tensor, dim_per_head: int) -> Tensor: def unshape(x: Tensor) -> Tensor: - """ - Compute context. 
- [batchsize x heads x length x dimperhead] + """[batchsize x heads x length x dimperhead] -> [batchsize x length x modeldim] """ x_0, x_1, _, x_3 = x.size() @@ -245,12 +72,7 @@ class MultiHeadedAttention(torch.nn.Module): is_decoder: bool, true if called by the Decoder layers """ - def __init__( - self, - model_config, - running_config=None, - is_decoder: bool = True, - ) -> None: + def __init__(self, model_config, running_config=None, is_decoder: bool = True) -> None: super(MultiHeadedAttention, self).__init__() self.dim_per_head = model_config.dim_per_head self.heads = model_config.heads @@ -278,7 +100,6 @@ def __init__( out_features=self.dim_per_head * self.heads // self.parallel_gpu, bias=model_config.add_qkvbias, ) - self.softmax = nn.Softmax(dim=-1) self.dropout_p = getattr(running_config, "attention_dropout", [0.0])[0] self.dropout = nn.Dropout(self.dropout_p) @@ -302,10 +123,9 @@ def __init__( # TODO find a cleaner way to initialize? self.relative_positions_embeddings = None self.relative_attention_bias = None - self.rotary_interleave = None + self.rotary_interleave = None # for flash_kvcache without rotary if self.relative_positions_buckets > 0: self.relative_attention_bias = nn.Embedding(self.relative_positions_buckets, self.heads) - self.relative_positions_embeddings = None elif self.position_encoding_type == PositionEncodingType.Relative: # https://arxiv.org/pdf/1803.02155.pdf # in the paper they suggest either two embeds @@ -406,12 +226,10 @@ def _compute_attention( if self.heads_kv < self.heads: qh = query.size(1) # expand key on heads dimension when it's less than query heads (multi-query variant) - key = key.view(b, -1, 1, l, d).repeat(1, 1, qh // h, 1, 1) - key = key.view(b, qh, l, d) + key = key.view(b, -1, 1, l, d).repeat(1, 1, qh // h, 1, 1).view(b, qh, l, d) # expand value on heads dimension when it's less than query heads (multi-query variant) - value = value.view(b, -1, 1, l, d).repeat(1, 1, qh // h, 1, 1) - value = value.view(b, qh, l, d) + value = value.view(b, -1, 1, l, d).repeat(1, 1, qh // h, 1, 1).view(b, qh, l, d) # 2) When standard pos. enc. or rotary, use flash attention or SDPA if ( @@ -420,11 +238,11 @@ def _compute_attention( and query.device.type != "cpu" ): # Apply pytorch scaled_dot_product_attention. - attn_output = scaled_dot_product_attention( + attn_output = F.scaled_dot_product_attention( query, key, value, - attn_mask if attn_mask is not None else None, + attn_mask, self.dropout_p, is_causal=False, ) @@ -476,7 +294,7 @@ def _compute_attention( scores = scores.masked_fill(~attn_mask, -1e18) # 3) Apply attention dropout and compute context vectors. 
- attn = self.softmax(scores).to(query.dtype) + attn = F.softmax(scores, dim=-1).to(query.dtype) drop_attn = self.dropout(attn) if self.dropout_p > 0 else attn attn_output = torch.matmul(drop_attn, value) @@ -503,17 +321,12 @@ def _compute_attention( class SelfMHA(MultiHeadedAttention): - def __init__( - self, - model_config, - running_config=None, - is_decoder: bool = True, - ) -> None: + def __init__(self, model_config, running_config=None, is_decoder: bool = True) -> None: self.position_encoding_type = model_config.position_encoding_type self.n_positions = model_config.n_positions super(SelfMHA, self).__init__(model_config, running_config, is_decoder) - def _expand_cache(self, add_length, step): + def _expand_cache(self, add_length: int, step: int) -> int: b, h, l, dph = self.layer_cache[1]["keys"].shape if step >= l: ktype = self.layer_cache[1]["keys"].dtype @@ -540,13 +353,8 @@ def _expand_cache(self, add_length, step): return step def _prepare_inputs_w_cache( - self, - query, - key, - value, - step: Optional[int] = 0, - position_embeddings=None, - ): + self, query: Tensor, key: Tensor, value: Tensor, step: Optional[int] = 0, position_embeddings=None + ) -> Tuple[Tensor, Tensor, Tensor]: if self.position_encoding_type == PositionEncodingType.Rotary: seqlen = query.size(2) cos = position_embeddings[0][step : step + seqlen] @@ -611,8 +419,7 @@ def forward( cos = position_embeddings[0][:, :rotdim].to(query.dtype) sin = position_embeddings[1][:, :rotdim].to(query.dtype) else: - cos = None - sin = None + cos, sin = None, None context = self.flash_attn_with_kvcache( query.transpose(1, 2), self.layer_cache[1]["keys"].transpose(1, 2), @@ -654,7 +461,7 @@ def __init__( self.n_positions = 0 super(ContextMHA, self).__init__(model_config, running_config, True) - def _prepare_inputs_w_cache(self, key, value, query): + def _prepare_inputs_w_cache(self, key: Tensor, value: Tensor, query: Tensor) -> Tuple[Tensor, Tensor, Tensor]: query = self.linear_query(query) query = shape(query, self.dim_per_head) if self.layer_cache[1]["keys"].numel() == 0: diff --git a/eole/modules/relative_position_bias.py b/eole/modules/relative_position_bias.py new file mode 100644 index 000000000..c5a51785e --- /dev/null +++ b/eole/modules/relative_position_bias.py @@ -0,0 +1,118 @@ +""" Relative position bias """ + +import torch +from math import log +from torch import Tensor +from typing import Optional + + +# Help functions for "Relative" position encoding +def relative_matmul(x: Tensor, z: Tensor, transpose: bool) -> Tensor: + """ + Helper function for relative positions attention. 
+ https://arxiv.org/pdf/1803.02155.pdf + x shape [batch_size x heads x q_len x k_len] + """ + batch_size, heads, length, _ = x.size() + x_t = x.permute(2, 0, 1, 3) + x_t_r = x_t.contiguous().view(length, heads * batch_size, -1) + if transpose: + z = z.transpose(1, 2) + x_tz_matmul = torch.matmul(x_t_r, z) + x_tz_matmul_r = x_tz_matmul.view(length, batch_size, heads, -1) + x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3) + return x_tz_matmul_r_t + + +def gen_relative_positions( + length: int, + n_positions: int, + cache: bool = False, + device: Optional[torch.device] = None, +) -> Tensor: + """Generate the clipped relative positions matrix + for a given length and maximum relative positions""" + if cache: + distance_mat = torch.arange(-length + 1, 1, 1, device=device).unsqueeze(0) + else: + range_vec = torch.arange(length, device=device) + range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1) + distance_mat = range_mat - range_mat.transpose(0, 1) + distance_mat_clipped = torch.clamp(distance_mat, min=-n_positions, max=n_positions) + # Shift values to be >= 0 + final_mat = distance_mat_clipped + n_positions + return final_mat + + +def _relative_position_bucket( + relative_position: Tensor, bidirectional: bool = True, num_buckets: int = 32, max_distance: int = 128 +) -> Tensor: + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/ + mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, + i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. + We use smaller buckets for small absolute relative_position and larger buckets for + larger absolute relative_positions. All relative positions >=max_distance map to the + same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the + model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values + in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions + # up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + +def compute_bias( + query_length: int, + key_length: int, + is_decoder: bool, + n_positions: int, + relative_positions_buckets: int, + device: Optional[torch.device] = None, +) -> Tensor: + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = _relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not is_decoder), + num_buckets=relative_positions_buckets, + max_distance=n_positions, + ) + return relative_position_bucket diff --git a/eole/modules/rope.py b/eole/modules/rope.py index b00a55008..21013f86c 100644 --- a/eole/modules/rope.py +++ b/eole/modules/rope.py @@ -1,6 +1,62 @@ import torch import torch.nn as nn import math +from torch import Tensor +from typing import Tuple + + +# Help functions for Rotary Embeddings +def rotate_half(x: Tensor) -> Tensor: + """Rotates half the hidden dims of the input.""" + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_emb( + query: Tensor, key: Tensor, rope: Tuple[Tensor, Tensor], interleave: bool +) -> Tuple[Tensor, Tensor]: + # now rope is a tuple (cos, sin) + cos, sin = rope + if interleave: + query, key = query.transpose(1, 2), key.transpose(1, 2) + query_ = query.float().reshape(*query.shape[:-1], -1, 2) + key_ = key.float().reshape(*key.shape[:-1], -1, 2) + + # Reshape cos and sin to match the dimensions of query_ and key_ + cos = cos[:, : cos.size(1) // 2].view(1, query_.size(1), 1, query_.size(3)) + sin = sin[:, : sin.size(1) // 2].view(1, key_.size(1), 1, key_.size(3)) + + query_rotated = query_[..., 0] * cos - query_[..., 1] * sin + query_rotated_imag = query_[..., 0] * sin + query_[..., 1] * cos + query_out = torch.stack((query_rotated, query_rotated_imag), dim=-1).flatten(3) + + key_rotated = key_[..., 0] * cos - key_[..., 1] * sin + 
key_rotated_imag = key_[..., 0] * sin + key_[..., 1] * cos + key_out = torch.stack((key_rotated, key_rotated_imag), dim=-1).flatten(3) + + return query_out.transpose(1, 2).type_as(query), key_out.transpose(1, 2).type_as(key) + + # Old code with complex instead + # rope_complex = torch.complex(cos, sin) + # query_ = torch.view_as_complex(query_) + # key_ = torch.view_as_complex(key_) + # query_out = torch.view_as_real(query_ * rope_complex).flatten(3) + # key_out = torch.view_as_real(key_ * rope_complex).flatten(3) + # return query_out.transpose(1, 2).type_as(query), key_out.transpose( + # 1, 2 + # ).type_as(key) + else: + rotary_dim = cos.size(1) + head_dim = query.size(3) + if rotary_dim < head_dim: + q_embed = (query[:, :, :, :rotary_dim] * cos) + (rotate_half(query[:, :, :, :rotary_dim]) * sin) + k_embed = (key[:, :, :, :rotary_dim] * cos) + (rotate_half(key[:, :, :, :rotary_dim]) * sin) + q_embed = torch.cat([q_embed, query[:, :, :, rotary_dim:]], dim=-1) + k_embed = torch.cat([k_embed, key[:, :, :, rotary_dim:]], dim=-1) + else: + q_embed = (query * cos) + (rotate_half(query) * sin) + k_embed = (key * cos) + (rotate_half(key) * sin) + return q_embed.type_as(query), k_embed.type_as(key) class RotaryPosition(nn.Module): @@ -47,14 +103,7 @@ def __init__(self, model_config): # TODO: extend with other scaling types if getattr(self.model_config.rope_config, "scaling_type", None) == "llama3": self.llama3_scaling() - tmax = torch.arange(1024) - rope = torch.outer(tmax, inv_freq) - cos = torch.cos(rope) - sin = torch.sin(rope) - cos = torch.cat((cos, cos), dim=-1) # Double the size by repeating `cos` - sin = torch.cat((sin, sin), dim=-1) # Double the size by repeating `sin` - self.register_buffer("cos", cos, persistent=False) - self.register_buffer("sin", sin, persistent=False) + self.update(1024) def llama3_scaling(self): """ @@ -111,7 +160,7 @@ def update(self, maxseqlen, step=0, prefetch=1024): step = 0 offset = 32 # make sure we have at least 32 positions for flash_attn_with_kvcache # This could probably a bit cleaner/homogenized with the offset case - if self.cos.size(0) >= max(offset + step, 0) + maxseqlen: + if hasattr(self, "cos") and self.cos.size(0) >= max(offset + step, 0) + maxseqlen: return self.cos, self.sin else: maxseqlen = maxseqlen + prefetch @@ -120,9 +169,11 @@ def update(self, maxseqlen, step=0, prefetch=1024): rope = torch.outer(tmax, self.inv_freq) cos = torch.cos(rope) sin = torch.sin(rope) - cos = torch.cat((cos, cos), dim=-1) # Double the size by repeating `cos` - sin = torch.cat((sin, sin), dim=-1) # Double the size by repeating `sin` + # cos = torch.cat((cos, cos), dim=-1) # Double the size by repeating `cos` + # sin = torch.cat((sin, sin), dim=-1) # Double the size by repeating `sin` + cos = torch.tile(cos, (1, 2)) + sin = torch.tile(sin, (1, 2)) self.register_buffer("cos", cos, persistent=False) self.register_buffer("sin", sin, persistent=False) return cos, sin From 0d194833ae4d172f3ec9e251bc80625d73211fb4 Mon Sep 17 00:00:00 2001 From: vince62s Date: Thu, 9 Jan 2025 10:39:28 +0100 Subject: [PATCH 08/11] black --- eole/modules/multi_headed_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eole/modules/multi_headed_attn.py b/eole/modules/multi_headed_attn.py index f69ac2320..fd5abb07b 100644 --- a/eole/modules/multi_headed_attn.py +++ b/eole/modules/multi_headed_attn.py @@ -123,7 +123,7 @@ def __init__(self, model_config, running_config=None, is_decoder: bool = True) - # TODO find a cleaner way to initialize? 
self.relative_positions_embeddings = None self.relative_attention_bias = None - self.rotary_interleave = None # for flash_kvcache without rotary + self.rotary_interleave = None # for flash_kvcache without rotary if self.relative_positions_buckets > 0: self.relative_attention_bias = nn.Embedding(self.relative_positions_buckets, self.heads) elif self.position_encoding_type == PositionEncodingType.Relative: From 895d6b08f7268e6674c196799d165aa150c003d3 Mon Sep 17 00:00:00 2001 From: vince62s Date: Fri, 10 Jan 2025 16:16:50 +0100 Subject: [PATCH 09/11] revamp attention padding for better accuracy --- eole/decoders/transformer.py | 36 ++++++------ eole/encoders/transformer.py | 8 +-- eole/inputters/dynamic_iterator.py | 1 + eole/inputters/text_utils.py | 2 + eole/models/model.py | 1 + eole/modules/multi_headed_attn.py | 82 +++++++++++++++------------- eole/predict/generator.py | 9 +-- eole/predict/inference.py | 7 +++ eole/predict/translator.py | 1 - eole/tests/test_model_lm/config.json | 1 + 10 files changed, 84 insertions(+), 64 deletions(-) diff --git a/eole/decoders/transformer.py b/eole/decoders/transformer.py index aea056b09..f2d5945b8 100644 --- a/eole/decoders/transformer.py +++ b/eole/decoders/transformer.py @@ -84,8 +84,8 @@ def forward(self, layer_in, **kwargs): """ enc_out = kwargs.pop("enc_out", None) - src_pad_mask = kwargs.pop("src_pad_mask", None) attn_mask = kwargs.pop("attn_mask", None) + src_pad_mask = kwargs.pop("src_pad_mask", None) step = kwargs.pop("step", None) return_attn = kwargs.pop("return_attn", False) position_embeddings = kwargs.pop("position_embeddings", None) @@ -211,6 +211,8 @@ def init_state(self, **kwargs): pass def map_state(self, fn): + z = fn(self.left_pad_mask, 0) + self.left_pad_mask = z for layer in self.transformer_layers: if self.with_cross_attn: if layer.context_attn.layer_cache[1]["keys"].numel() != 0: @@ -220,15 +222,7 @@ def map_state(self, fn): if layer.self_attn.layer_cache[1]["keys"].numel() != 0: x = fn(layer.self_attn.layer_cache[1]["keys"], 0) y = fn(layer.self_attn.layer_cache[1]["values"], 0) - if layer.self_attn.layer_cache[1].get("key_pad_mask", None) is not None: - z = fn(layer.self_attn.layer_cache[1]["key_pad_mask"], 0) - else: - z = None - layer.self_attn.layer_cache = True, { - "keys": x, - "values": y, - "key_pad_mask": z, - } + layer.self_attn.layer_cache = True, {"keys": x, "values": y} def update_dropout(self, dropout, attention_dropout): for layer in self.transformer_layers: @@ -275,27 +269,35 @@ def forward(self, emb, **kwargs): src_pad_mask = kwargs.pop("src_pad_mask", None) if self.with_cross_attn: assert src_pad_mask is not None, "TransformerDecoder requires a src pad mask" + left_pad = kwargs.pop("left_pad", False) step = kwargs.pop("step", None) with_align = kwargs.pop("with_align", False) return_attn = with_align or kwargs.pop("return_attn", False) position_embeddings = kwargs.pop("position_embeddings", None) attn_aligns = [] + if step == 0: + self._enable_cache(emb.device, tgt_pad_mask) + if emb.size(1) > 1: - # causal masking is necessary when sequence length is greater than one + # training or first step decoding attn_mask = self._causal_attn_mask(tgt_pad_mask) if self.with_cross_attn: src_pad_mask = src_pad_mask.unsqueeze(1).expand(-1, -1, attn_mask.size(3), -1) # mask now are (batch x 1 x tlen x s or t len) # dim 1 to be broadcasted to heads in MHA else: - attn_mask = None + # step by step decoding + # we rebuild the pad mask of previous steps (left padded prompt) + if left_pad: + pad_mask = 
torch.zeros([emb.size(0), 1, step + 1], dtype=torch.bool, device=emb.device) + pad_mask[:, :, : self.left_pad_mask.size(2)] = self.left_pad_mask + attn_mask = ~pad_mask.unsqueeze(1) + else: + attn_mask = None if self.with_cross_attn: src_pad_mask = src_pad_mask.unsqueeze(1) - if step == 0: - self._enable_cache(emb.device, tgt_pad_mask) - for layer in self.transformer_layers: emb, attn = layer( emb, @@ -331,6 +333,7 @@ def forward(self, emb, **kwargs): return emb, attns def _enable_cache(self, device, pad_mask): + self.left_pad_mask = pad_mask for layer in self.transformer_layers: # first value set to True triggered by the beginning of decoding # layer_cache becomes active in the MultiHeadedAttention fwd @@ -339,7 +342,6 @@ def _enable_cache(self, device, pad_mask): { "keys": torch.tensor([], device=device), "values": torch.tensor([], device=device), - "key_pad_mask": pad_mask, }, ) if layer.context_attn: @@ -352,13 +354,13 @@ def _enable_cache(self, device, pad_mask): ) def _disable_cache(self): + self.left_pad_mask = torch.tensor([]) for layer in self.transformer_layers: layer.self_attn.layer_cache = ( False, { "keys": torch.tensor([]), "values": torch.tensor([]), - "key_pad_mask": None, }, ) if layer.context_attn: diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py index 78cbda900..e6b9c88d6 100644 --- a/eole/encoders/transformer.py +++ b/eole/encoders/transformer.py @@ -122,19 +122,17 @@ def forward(self, emb, **kwargs): * enc_out ``(batch_size, src_len, model_dim)`` * encoder final state: None in the case of Transformer """ - pad_mask = kwargs.pop("pad_mask", None) assert pad_mask is not None, "TransformerEncoder requires a src pad mask" position_embeddings = kwargs.pop("position_embeddings", None) - enc_out = emb pad_mask = pad_mask.unsqueeze(1) # batch x 1 x 1 x maxlen pad_mask = pad_mask.expand(-1, -1, pad_mask.size(3), -1) # batch x 1 x maxlen x maxlen # 1 to be expanded to number of heads in MHA for layer in self.transformer_layers: - enc_out = layer(enc_out, pad_mask, position_embeddings=position_embeddings) - enc_out = self.layer_norm(enc_out) - return enc_out, None + emb = layer(emb, pad_mask, position_embeddings=position_embeddings) + emb = self.layer_norm(emb) + return emb, None def update_dropout(self, dropout, attention_dropout): for layer in self.transformer_layers: diff --git a/eole/inputters/dynamic_iterator.py b/eole/inputters/dynamic_iterator.py index d40dc215f..cb6c42177 100644 --- a/eole/inputters/dynamic_iterator.py +++ b/eole/inputters/dynamic_iterator.py @@ -391,6 +391,7 @@ def __iter__(self): "cid", "ind_in_bucket", "cid_line_number", + "left_pad", ]: tensor_batch[key] = tensor_batch[key].to(self.device) yield (tensor_batch, bucket_idx) diff --git a/eole/inputters/text_utils.py b/eole/inputters/text_utils.py index 6dae40c6d..b61f73978 100644 --- a/eole/inputters/text_utils.py +++ b/eole/inputters/text_utils.py @@ -252,4 +252,6 @@ def tensorify(vocabs, minibatch, device, left_pad=False): if minibatch[0][0]["cid"] != "infer": tensor_batch["sco"] = torch.tensor([ex["sco"] for ex, indice in minibatch], device=device) + tensor_batch["left_pad"] = left_pad + return tensor_batch diff --git a/eole/models/model.py b/eole/models/model.py index de06bb8d4..0a95d4afb 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -352,6 +352,7 @@ def load_test_model(cls, config, device_id=0, model_path=None): model, vocabs, model_config = get_model_class(config.model).from_config( config, vocabs, checkpoint, device_id, running_config=config ) + return 
vocabs, model, model_config def init_weights(self, running_config): diff --git a/eole/modules/multi_headed_attn.py b/eole/modules/multi_headed_attn.py index fd5abb07b..f49f6a9b5 100644 --- a/eole/modules/multi_headed_attn.py +++ b/eole/modules/multi_headed_attn.py @@ -244,7 +244,6 @@ def _compute_attention( value, attn_mask, self.dropout_p, - is_causal=False, ) attn = None else: @@ -326,34 +325,44 @@ def __init__(self, model_config, running_config=None, is_decoder: bool = True) - self.n_positions = model_config.n_positions super(SelfMHA, self).__init__(model_config, running_config, is_decoder) - def _expand_cache(self, add_length: int, step: int) -> int: - b, h, l, dph = self.layer_cache[1]["keys"].shape - if step >= l: - ktype = self.layer_cache[1]["keys"].dtype - kdev = self.layer_cache[1]["keys"].device - self.layer_cache[1]["keys"] = torch.cat( - ( - self.layer_cache[1]["keys"], - torch.zeros((b, h, add_length, dph), device=kdev, dtype=ktype), - ), - dim=2, - ) - self.layer_cache[1]["values"] = torch.cat( - ( - self.layer_cache[1]["values"], - torch.zeros((b, h, add_length, dph), device=kdev, dtype=ktype), - ), - dim=2, - ) + def _expand_cache(self, add_length: int, step: int, key: Tensor) -> int: + if step == 0: + self.layer_cache[1]["keys"] = key.new_zeros(key.shape) + self.layer_cache[1]["values"] = key.new_zeros(key.shape) + else: + b, h, l, dph = self.layer_cache[1]["keys"].shape + if step >= l: + ktype = self.layer_cache[1]["keys"].dtype + kdev = self.layer_cache[1]["keys"].device + self.layer_cache[1]["keys"] = torch.cat( + ( + self.layer_cache[1]["keys"], + torch.zeros((b, h, add_length, dph), device=kdev, dtype=ktype), + ), + dim=2, + ) + self.layer_cache[1]["values"] = torch.cat( + ( + self.layer_cache[1]["values"], + torch.zeros((b, h, add_length, dph), device=kdev, dtype=ktype), + ), + dim=2, + ) if self.sliding_window > 0 and l > self.sliding_window: self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][:, :, 1:, :] self.layer_cache[1]["values"] = self.layer_cache[1]["values"][:, :, 1:, :] return self.sliding_window else: - return step + return step + 1 def _prepare_inputs_w_cache( - self, query: Tensor, key: Tensor, value: Tensor, step: Optional[int] = 0, position_embeddings=None + self, + query: Tensor, + key: Tensor, + value: Tensor, + step: Optional[int] = 0, + attn_mask: Optional[Tensor] = None, + position_embeddings=None, ) -> Tuple[Tensor, Tensor, Tensor]: if self.position_encoding_type == PositionEncodingType.Rotary: seqlen = query.size(2) @@ -362,23 +371,17 @@ def _prepare_inputs_w_cache( query, key = apply_rotary_emb(query, key, (cos, sin), interleave=self.rotary_interleave) if step == 0: - # mask keys for LM left padding by batch - key_pad_mask = self.layer_cache[1].get("key_pad_mask", None) - if key_pad_mask is not None: - x = key_pad_mask.expand(-1, key.size(1), -1) - x = x.unsqueeze(3).expand(-1, -1, -1, key.size(3)) - key = key.masked_fill(x, 0) # init cache with initial key, value self.layer_cache[1]["keys"] = key self.layer_cache[1]["values"] = value return key, value, query else: - cache_len = self._expand_cache(32, step) - self.layer_cache[1]["keys"][:, :, cache_len, :] = key[:, :, 0, :] - self.layer_cache[1]["values"][:, :, cache_len, :] = value[:, :, 0, :] + cache_len = self._expand_cache(32, step, key) + self.layer_cache[1]["keys"][:, :, cache_len - 1, :] = key[:, :, 0, :] + self.layer_cache[1]["values"][:, :, cache_len - 1, :] = value[:, :, 0, :] return ( - self.layer_cache[1]["keys"][:, :, : cache_len + 1, :], - self.layer_cache[1]["values"][:, 
:, : cache_len + 1, :],
+                self.layer_cache[1]["keys"][:, :, :cache_len, :],
+                self.layer_cache[1]["values"][:, :, :cache_len, :],
                 query,
             )
 
@@ -398,9 +401,9 @@ def forward(
         key = shape(key, self.dim_per_head)
         value = shape(value, self.dim_per_head)
         query = shape(query, self.dim_per_head)
+
         if (
-            step == 0
-            or not self.flash
+            not self.flash
             or self.position_encoding_type in [PositionEncodingType.Relative, PositionEncodingType.Alibi]
             or query.dtype not in [torch.float16, torch.bfloat16]  # to match with flash
         ):
@@ -409,17 +412,20 @@ def forward(
                 key,
                 value,
                 step=step,
+                attn_mask=attn_mask,
                 position_embeddings=position_embeddings,
             )
         else:
             # Fast path with flash_attn_with_kvcache
-            cache_len = self._expand_cache(32, step)
+            cache_len = self._expand_cache(32, step, key)
             if position_embeddings is not None:
                 rotdim = self.rotary_dim // 2
                 cos = position_embeddings[0][:, :rotdim].to(query.dtype)
                 sin = position_embeddings[1][:, :rotdim].to(query.dtype)
             else:
                 cos, sin = None, None
+            # restore initial tgt_pad_mask - might be better to store it instead.
+            cache_leftpad = (~torch.any(attn_mask, dim=-2).squeeze(1)).sum(dim=1).to(torch.int32)
             context = self.flash_attn_with_kvcache(
                 query.transpose(1, 2),
                 self.layer_cache[1]["keys"].transpose(1, 2),
@@ -428,7 +434,9 @@ def forward(
                 value.transpose(1, 2),
                 rotary_cos=cos.contiguous(),
                 rotary_sin=sin.contiguous(),
-                cache_seqlens=cache_len,
+                cache_seqlens=cache_len - 1,
+                cache_leftpad=cache_leftpad,
+                causal=step == 0,
                 rotary_interleaved=self.rotary_interleave,
             ).transpose(1, 2)
             attn_output = self.final_linear(unshape(context))
diff --git a/eole/predict/generator.py b/eole/predict/generator.py
index 6e57dd6f8..69c83cc78 100644
--- a/eole/predict/generator.py
+++ b/eole/predict/generator.py
@@ -78,8 +78,8 @@ def split_src_to_prevent_padding(cls, src, src_len):
         min_len_batch = torch.min(src_len).item()
         target_prefix = None
         if min_len_batch > 0 and min_len_batch < src.size(1):
-            target_prefix = src[:, min_len_batch:, :]
-            src = src[:, :min_len_batch, :]
+            target_prefix = src[:, min_len_batch:]
+            src = src[:, :min_len_batch]
             src_len[:] = min_len_batch
         return src, src_len, target_prefix
 
@@ -90,7 +90,7 @@ def tile_to_beam_size_after_initial_step(self, fn_map_state, log_probs):
             log_probs = log_probs[:, -1, :]
         return log_probs
 
-    def _predict_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
+    def _predict_batch_with_strategy(self, batch, decode_strategy):
         """Predict a batch of sentences step by step using cache.
Args: @@ -109,7 +109,7 @@ def _predict_batch_with_strategy(self, batch, decode_strategy, left_pad=True): src = batch["src"] src_len = batch["srclen"] - if left_pad: + if batch["left_pad"]: target_prefix = None else: src, src_len, target_prefix = self.split_src_to_prevent_padding(src, src_len) @@ -134,6 +134,7 @@ def _predict_batch_with_strategy(self, batch, decode_strategy, left_pad=True): None, src_len=decode_strategy.src_len, step=step if step == 0 else step + max(src_len.tolist()), + left_pad=batch["left_pad"], ) if step == 0: diff --git a/eole/predict/inference.py b/eole/predict/inference.py index 3145566da..89792bfe3 100644 --- a/eole/predict/inference.py +++ b/eole/predict/inference.py @@ -163,6 +163,11 @@ def __init__( self.add_estimator = add_estimator self.id_tokenization = id_tokenization + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_math_sdp(False) + torch.backends.cuda.enable_cudnn_sdp(False) + @classmethod def from_config( cls, @@ -606,6 +611,7 @@ def _decode_and_generate( src_len, step=None, return_attn=False, + left_pad=False, ): # Decoder forward, takes [batch, tgt_len, nfeats] as input @@ -632,6 +638,7 @@ def _decode_and_generate( src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask, position_embeddings=position_embeddings, + left_pad=left_pad, ) # Generator forward. if "std" in dec_attn: diff --git a/eole/predict/translator.py b/eole/predict/translator.py index 37bc64cd9..4cd7c9039 100644 --- a/eole/predict/translator.py +++ b/eole/predict/translator.py @@ -162,7 +162,6 @@ def _translate_batch_with_strategy(self, batch, decode_strategy): # (1) Run the encoder on the src. src, enc_final_hs, enc_out, src_len = self._run_encoder(batch) - self.model.decoder.init_state(src=src, enc_out=enc_out, enc_final_hs=enc_final_hs) gold_score, gold_log_probs = self._gold_score( diff --git a/eole/tests/test_model_lm/config.json b/eole/tests/test_model_lm/config.json index 71faf629e..0cf7c70ea 100644 --- a/eole/tests/test_model_lm/config.json +++ b/eole/tests/test_model_lm/config.json @@ -11,6 +11,7 @@ "heads": 2 }, "architecture": "custom", + "left_pad": false, "encoder": null, "embeddings": { "tgt_word_vec_size": 64, From 57b8943be4fcc1144acb3f1de8f4093ffd2ca7f1 Mon Sep 17 00:00:00 2001 From: vince62s Date: Fri, 10 Jan 2025 16:55:46 +0100 Subject: [PATCH 10/11] clean --- eole/bin/run/predict.py | 5 +++++ eole/bin/run/train.py | 5 +++++ eole/decoders/transformer.py | 12 ++++-------- eole/encoders/transformer.py | 3 +-- eole/modules/rope.py | 6 ++---- eole/predict/inference.py | 5 ----- eole/train_single.py | 6 ------ 7 files changed, 17 insertions(+), 25 deletions(-) diff --git a/eole/bin/run/predict.py b/eole/bin/run/predict.py index 637e3be1b..241188d60 100644 --- a/eole/bin/run/predict.py +++ b/eole/bin/run/predict.py @@ -1,5 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import torch from eole.inference_engine import InferenceEnginePY, InferenceEngineCT2 from eole.constants import ModelType from argparse import ArgumentParser @@ -36,6 +37,10 @@ def predict(config): class Predict(RunBin): config_class = PredictConfig require_config = False + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_math_sdp(False) + torch.backends.cuda.enable_cudnn_sdp(False) @classmethod def run(cls, args): diff --git a/eole/bin/run/train.py b/eole/bin/run/train.py index d618773a9..69f3376b2 100644 --- a/eole/bin/run/train.py +++ 
b/eole/bin/run/train.py @@ -24,6 +24,11 @@ def train(config): config._validate_vocab_config() init_logger(config.log_file) set_random_seed(config.seed, False) + # Allow only Memory Efficient path for sdpa + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_math_sdp(False) + torch.backends.cuda.enable_cudnn_sdp(False) train_process = partial(single_main) diff --git a/eole/decoders/transformer.py b/eole/decoders/transformer.py index f2d5945b8..d9a2445d4 100644 --- a/eole/decoders/transformer.py +++ b/eole/decoders/transformer.py @@ -238,8 +238,8 @@ def _causal_attn_mask(self, tgt_pad_mask): if self.sliding_window > 0: future_mask = future_mask.triu_(-self.sliding_window) attn_mask = ~tgt_pad_mask & future_mask.unsqueeze(0) - attn_mask = attn_mask.unsqueeze(1) - attn_mask = attn_mask.expand(-1, -1, attn_mask.size(3), -1) + attn_mask = attn_mask.unsqueeze(1) # (batch x 1 x 1 x tgt_len) + # dim 1 (heads) and 2 (tgt_len) will be broadcasted automatically in MHA return attn_mask def forward(self, emb, **kwargs): @@ -269,6 +269,8 @@ def forward(self, emb, **kwargs): src_pad_mask = kwargs.pop("src_pad_mask", None) if self.with_cross_attn: assert src_pad_mask is not None, "TransformerDecoder requires a src pad mask" + src_pad_mask = src_pad_mask.unsqueeze(1) # (batch x 1 x 1 x src_len) + # dim 1 (heads) and 2 (tgt_len) will be broadcasted automatically in MHA left_pad = kwargs.pop("left_pad", False) step = kwargs.pop("step", None) with_align = kwargs.pop("with_align", False) @@ -282,10 +284,6 @@ def forward(self, emb, **kwargs): if emb.size(1) > 1: # training or first step decoding attn_mask = self._causal_attn_mask(tgt_pad_mask) - if self.with_cross_attn: - src_pad_mask = src_pad_mask.unsqueeze(1).expand(-1, -1, attn_mask.size(3), -1) - # mask now are (batch x 1 x tlen x s or t len) - # dim 1 to be broadcasted to heads in MHA else: # step by step decoding # we rebuild the pad mask of previous steps (left padded prompt) @@ -295,8 +293,6 @@ def forward(self, emb, **kwargs): attn_mask = ~pad_mask.unsqueeze(1) else: attn_mask = None - if self.with_cross_attn: - src_pad_mask = src_pad_mask.unsqueeze(1) for layer in self.transformer_layers: emb, attn = layer( diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py index e6b9c88d6..8e18ce96a 100644 --- a/eole/encoders/transformer.py +++ b/eole/encoders/transformer.py @@ -126,8 +126,7 @@ def forward(self, emb, **kwargs): assert pad_mask is not None, "TransformerEncoder requires a src pad mask" position_embeddings = kwargs.pop("position_embeddings", None) pad_mask = pad_mask.unsqueeze(1) # batch x 1 x 1 x maxlen - pad_mask = pad_mask.expand(-1, -1, pad_mask.size(3), -1) # batch x 1 x maxlen x maxlen - # 1 to be expanded to number of heads in MHA + # dim 1 (heads) and 2 (src_len) will be broadcasted automatically in MHA for layer in self.transformer_layers: emb = layer(emb, pad_mask, position_embeddings=position_embeddings) diff --git a/eole/modules/rope.py b/eole/modules/rope.py index 21013f86c..4d2289a2e 100644 --- a/eole/modules/rope.py +++ b/eole/modules/rope.py @@ -169,11 +169,9 @@ def update(self, maxseqlen, step=0, prefetch=1024): rope = torch.outer(tmax, self.inv_freq) cos = torch.cos(rope) sin = torch.sin(rope) - # cos = torch.cat((cos, cos), dim=-1) # Double the size by repeating `cos` - # sin = torch.cat((sin, sin), dim=-1) # Double the size by repeating `sin` + cos = torch.cat((cos, cos), dim=-1) # Double the size by repeating `cos` + sin = 
torch.cat((sin, sin), dim=-1) # Double the size by repeating `sin` - cos = torch.tile(cos, (1, 2)) - sin = torch.tile(sin, (1, 2)) self.register_buffer("cos", cos, persistent=False) self.register_buffer("sin", sin, persistent=False) return cos, sin diff --git a/eole/predict/inference.py b/eole/predict/inference.py index 89792bfe3..b48db1a7a 100644 --- a/eole/predict/inference.py +++ b/eole/predict/inference.py @@ -163,11 +163,6 @@ def __init__( self.add_estimator = add_estimator self.id_tokenization = id_tokenization - torch.backends.cuda.enable_mem_efficient_sdp(True) - torch.backends.cuda.enable_flash_sdp(False) - torch.backends.cuda.enable_math_sdp(False) - torch.backends.cuda.enable_cudnn_sdp(False) - @classmethod def from_config( cls, diff --git a/eole/train_single.py b/eole/train_single.py index 41ca2fadd..518bb4909 100644 --- a/eole/train_single.py +++ b/eole/train_single.py @@ -134,12 +134,6 @@ def main(config, device_id): init_logger(config.log_file) checkpoint, vocabs, transforms, config = _init_train(config) - # Allow only Memory Efficient path for sdpa - torch.backends.cuda.enable_mem_efficient_sdp(True) - torch.backends.cuda.enable_flash_sdp(False) - torch.backends.cuda.enable_math_sdp(False) - torch.backends.cuda.enable_cudnn_sdp(False) - # if transform + options set in 'valid' we need to copy in main # transform / options for scoring considered as inference validset_transforms = getattr(config.data.get("valid", None), "transforms", None) From fe8f63c6910a9eefcc048666bf4a3a0661b48893 Mon Sep 17 00:00:00 2001 From: vince62s Date: Mon, 13 Jan 2025 09:17:22 +0100 Subject: [PATCH 11/11] adjust to comments --- eole/bin/run/predict.py | 5 ----- eole/bin/run/train.py | 5 ----- eole/inference_engine.py | 4 ++++ eole/train_single.py | 6 ++++++ 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/eole/bin/run/predict.py b/eole/bin/run/predict.py index 241188d60..637e3be1b 100644 --- a/eole/bin/run/predict.py +++ b/eole/bin/run/predict.py @@ -1,6 +1,5 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -import torch from eole.inference_engine import InferenceEnginePY, InferenceEngineCT2 from eole.constants import ModelType from argparse import ArgumentParser @@ -37,10 +36,6 @@ def predict(config): class Predict(RunBin): config_class = PredictConfig require_config = False - torch.backends.cuda.enable_mem_efficient_sdp(True) - torch.backends.cuda.enable_flash_sdp(False) - torch.backends.cuda.enable_math_sdp(False) - torch.backends.cuda.enable_cudnn_sdp(False) @classmethod def run(cls, args): diff --git a/eole/bin/run/train.py b/eole/bin/run/train.py index 69f3376b2..d618773a9 100644 --- a/eole/bin/run/train.py +++ b/eole/bin/run/train.py @@ -24,11 +24,6 @@ def train(config): config._validate_vocab_config() init_logger(config.log_file) set_random_seed(config.seed, False) - # Allow only Memory Efficient path for sdpa - torch.backends.cuda.enable_mem_efficient_sdp(True) - torch.backends.cuda.enable_flash_sdp(False) - torch.backends.cuda.enable_math_sdp(False) - torch.backends.cuda.enable_cudnn_sdp(False) train_process = partial(single_main) diff --git a/eole/inference_engine.py b/eole/inference_engine.py index 7b36b2b56..c7940340f 100755 --- a/eole/inference_engine.py +++ b/eole/inference_engine.py @@ -21,6 +21,10 @@ class InferenceEngine(object): def __init__(self, config): self.config = config # PredictConfig self.model_type = None + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_math_sdp(False) + 
torch.backends.cuda.enable_cudnn_sdp(False) def predict_batch(self, batch): pass diff --git a/eole/train_single.py b/eole/train_single.py index 518bb4909..41ca2fadd 100644 --- a/eole/train_single.py +++ b/eole/train_single.py @@ -134,6 +134,12 @@ def main(config, device_id): init_logger(config.log_file) checkpoint, vocabs, transforms, config = _init_train(config) + # Allow only Memory Efficient path for sdpa + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_math_sdp(False) + torch.backends.cuda.enable_cudnn_sdp(False) + # if transform + options set in 'valid' we need to copy in main # transform / options for scoring considered as inference validset_transforms = getattr(config.data.get("valid", None), "transforms", None)
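
For reference, a minimal sketch (not part of the patch series) of what the cache_leftpad expression added to eole/modules/multi_headed_attn.py computes. During step-by-step decoding the attention mask has shape (batch, 1, 1, key_len); a key position that no query may attend to is padding, and with left-padded prompts those positions sit contiguously on the left, so counting them per row gives the value passed to flash_attn_with_kvcache's cache_leftpad argument. The toy shapes and values below are illustrative only.

import torch

# Toy boolean attention mask, shape (batch=2, 1, tgt_len=1, key_len=5):
# True means "may attend", False means "masked out".
# Row 0 carries two left-pad tokens, row 1 carries none (illustrative values).
attn_mask = torch.tensor(
    [
        [[[False, False, True, True, True]]],
        [[[True, True, True, True, True]]],
    ]
)

# Collapse the query dimension, drop the broadcast head dimension, then count
# the key positions that no query attends to, i.e. the left padding per row.
cache_leftpad = (~torch.any(attn_mask, dim=-2).squeeze(1)).sum(dim=1).to(torch.int32)
print(cache_leftpad)  # tensor([2, 0], dtype=torch.int32)

In the patch itself this mask is rebuilt from the stored left_pad_mask at every decoding step rather than being cached per attention layer, which is why key_pad_mask was dropped from the layer cache.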