diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index defcb0b92..3d62ed630 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -25,7 +25,7 @@ jobs: pip install sacrebleu pip install flake8 pip install rich - python -m pip install black flake8 pyproject-flake8 + python -m pip install "black==24.10.0" "flake8<7.1" pyproject-flake8 if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Check code with Black run: | diff --git a/eole/bin/convert/convert_HF.py b/eole/bin/convert/convert_HF.py index 0886d57cb..36e8d1fd6 100755 --- a/eole/bin/convert/convert_HF.py +++ b/eole/bin/convert/convert_HF.py @@ -25,6 +25,7 @@ from eole.config.models import ( TransformerEncoderModelConfig, TransformerLMModelConfig, + VisionTransformerLMModelConfig, ) from eole.config.run import TrainConfig from eole.config.training import TrainingConfig @@ -122,6 +123,33 @@ "encoder.layer_norm.weight": "roberta.encoder.LayerNorm.weight", "encoder.layer_norm.bias": "roberta.encoder.LayerNorm.bias", }, + "LlavaForConditionalGeneration": { + "decoder_layer_prefix": "language_model.model.layers.", + "tgt_emb.embeddings.weight": "language_model.model.embed_tokens.weight", + "decoder.layer_norm.weight": "language_model.model.norm.weight", + "generator.weight": "language_model.lm_head.weight", + "encoder.patch_conv.weight": "vision_tower.patch_conv.weight", + "encoder.ln_pre.weight": "vision_tower.ln_pre.weight", + # vision_tower + "encoder_layer_prefix": "vision_tower.transformer.layers.", + "encoder": { + "layers": 24, + ".self_attn.linear_query.": ".attention.q_proj.", + ".self_attn.linear_keys.": ".attention.k_proj.", + ".self_attn.linear_values.": ".attention.v_proj.", + ".self_attn.final_linear.": ".attention.o_proj.", + ".mlp.gate_up_proj.": ".feed_forward.gate_proj.", + ".mlp.down_proj.": ".feed_forward.down_proj.", + ".mlp.up_proj.": ".feed_forward.up_proj.", + ".input_layernorm.weight": ".attention_norm.weight", # not sure about this one + ".post_attention_layernorm.weight": ".ffn_norm.weight", + }, + # vision_adapter + "adapter.w_in.weight": "multi_modal_projector.linear_1.weight", + "adapter.w_in.bias": "multi_modal_projector.linear_1.bias", + "adapter.w_out.weight": "multi_modal_projector.linear_2.weight", + "adapter.w_out.bias": "multi_modal_projector.linear_2.bias", + }, } # Combine base mappings with overrides @@ -152,7 +180,10 @@ # Eole config class ARCH_TABLE = defaultdict( lambda: TransformerLMModelConfig, - {"XLMRobertaXLForMaskedLM": TransformerEncoderModelConfig}, + { + "XLMRobertaXLForMaskedLM": TransformerEncoderModelConfig, + "LlavaForConditionalGeneration": VisionTransformerLMModelConfig, + }, ) # Default tokenization transform @@ -284,6 +315,11 @@ def __getattr__(self, name): def arch(self): return self.config["architectures"][0] + @property + def vocab_size(self): + config = self.config.get("text_config", self.config) + return config["vocab_size"] + @property def encoder_layer_prefix(self): return KEY_MAPS[self.arch].get("encoder_layer_prefix", None) @@ -381,14 +417,19 @@ def build_config_dict(hf): config = hf.config arch = hf.arch + vision_config = config.get("vision_config", None) + config = config.get("text_config", config) + model_config = {} training_config = {} # Initialize model_config with defaults and fallbacks model_config = { - "layers": config.get("num_hidden_layers", config.get("n_layer")), - "hidden_size": config.get("hidden_size", config.get("n_embd")), - "heads": config.get("num_attention_heads", config.get("n_head")), + 
"layers": config.get("num_hidden_layers", config.get("n_layer", config.get("n_layers"))), + "hidden_size": config.get("hidden_size", config.get("n_embd", config.get("hidden_dim"))), + "heads": config.get( + "num_attention_heads", config.get("n_head", config.get("n_heads", 32)) + ), # default 32 patch for mistral-community/pixtral-12b "transformer_ff": config.get("intermediate_size", config.get("hidden_size", config.get("n_embd")) * 4), "mlp_activation_fn": ACT_TABLE[arch], "layer_norm": LN_TABLE[arch], @@ -561,6 +602,29 @@ def build_config_dict(hf): }, } + # Vision encoder + if arch == "LlavaForConditionalGeneration": + # TODO: extend to other Llava models (with CLIP vision encoder) + model_config["encoder"] = { + "mlp_activation_fn": model_config["mlp_activation_fn"], + "layer_norm": model_config["layer_norm"], + "norm_eps": model_config["norm_eps"], + "hidden_size": vision_config["image_size"], + "transformer_ff": vision_config["image_size"] * 4, # hard-coded for mistral-community/pixtral-12b + "num_channels": 3, + "image_size": vision_config["image_size"], + "patch_size": vision_config["patch_size"], + "rope_config": { + "rotary_theta": vision_config["rope_theta"], + "rotary_interleave": False, + }, + "layers": 24, # hard-coded for mistral-community/pixtral-12b + "heads": vision_config["image_size"] / vision_config["head_dim"], + "heads_kv": vision_config["image_size"] / vision_config["head_dim"], + "head_dim": vision_config["head_dim"], + "image_token_id": 10, + } + # Update model_config based on architecture if arch in arch_configs: for key, value in arch_configs[arch].items(): @@ -673,7 +737,7 @@ def get_shards_map(model_config, hf, nshards): # Check if a layer key belongs to the current shard def is_layer_in_range(key, prefix, layer_range): - return prefix and key.startswith(prefix) and int(key.split(".")[2]) in layer_range + return prefix and key.startswith(prefix) and int(key.split(".")[len(prefix.split(".")) - 1]) in layer_range # Loop over the weightmap and distribute checkpoints to the appropriate shards for key, ckpt in weightmap.items(): @@ -716,6 +780,12 @@ def build_shards(model_config, hf, args, params): "encoder.layer_norm.weight", "encoder.layer_norm.bias", "generator.weight", + "encoder.patch_conv.weight", + "encoder.ln_pre.weight", + "adapter.w_in.weight", + "adapter.w_in.bias", + "adapter.w_out.weight", + "adapter.w_out.bias", ] def build_first_shard(hf, eole_safetensor): @@ -754,26 +824,30 @@ def build_first_shard(hf, eole_safetensor): if hf_prefix is None: continue for param in params: - for target, source in KEY_MAPS[hf.arch].items(): - if target in first_shard_targets: - continue - srckey, srcmap = source if isinstance(source, tuple) else (source, None) - w = get_weight( - checkpoint, - hf_prefix + str(i) + srckey + param, - ) - - if w is not None: - if srcmap is not None: - w = eval( - "w" + srcmap, - { - "w": w, - "hidden_size": model_config["hidden_size"], - "transformer_ff": model_config["transformer_ff"], - }, - ).contiguous() - eole_safetensor[eole_prefix + str(i) + target + param] = w + # TODO: factorize this better + for key_map in [KEY_MAPS[hf.arch], KEY_MAPS[hf.arch].get("encoder", {})]: + for target, source in key_map.items(): + if not isinstance(source, str): + continue + if target in first_shard_targets: + continue + srckey, srcmap = source if isinstance(source, tuple) else (source, None) + w = get_weight( + checkpoint, + hf_prefix + str(i) + srckey + param, + ) + + if w is not None: + if srcmap is not None: + w = eval( + "w" + srcmap, + { + "w": w, + 
"hidden_size": model_config["hidden_size"], + "transformer_ff": model_config["transformer_ff"], + }, + ).contiguous() + eole_safetensor[eole_prefix + str(i) + target + param] = w if model_config["shared_layer_norm"]: idx = 0 @@ -789,26 +863,30 @@ def build_first_shard(hf, eole_safetensor): "post_feedforward_layernorm", "mlp.gate", ]: + if hf_prefix == hf.encoder_layer_prefix: + source_map = KEY_MAPS[hf.arch]["encoder"] + else: + source_map = KEY_MAPS[hf.arch] module_p = f".{module}.{p}" - if module_p in KEY_MAPS[hf.arch].keys(): - if isinstance(KEY_MAPS[hf.arch][module_p], tuple): + if module_p in source_map.keys(): + if isinstance(source_map[module_p], tuple): w = get_weight( checkpoint, - hf_prefix + str(i) + KEY_MAPS[hf.arch][module_p][idx], + hf_prefix + str(i) + source_map[module_p][idx], ) else: w = get_weight( checkpoint, - hf_prefix + str(i) + KEY_MAPS[hf.arch][module_p], + hf_prefix + str(i) + source_map[module_p], ) if w is not None: eole_safetensor[eole_prefix + str(i) + module_p] = w for j in range(model_config["num_experts"]): - if f".mlp.experts.{j}.layer_norm." + p in KEY_MAPS[hf.arch].keys(): + if f".mlp.experts.{j}.layer_norm." + p in source_map.keys(): w = get_weight( checkpoint, - hf_prefix + str(i) + KEY_MAPS[hf.arch][f".mlp.experts.{j}.layer_norm." + p], + hf_prefix + str(i) + source_map[f".mlp.experts.{j}.layer_norm." + p], ) if w is not None: eole_safetensor[eole_prefix + str(i) + f".mlp.experts.{j}.layer_norm." + p] = w @@ -840,7 +918,8 @@ def check_sentencepiece_tokenizer(hf): def check_bpe_tokenizer(hf, vocabs, directory_path): - vocab_size = hf.config["vocab_size"] + config = hf.config.get("text_config", hf.config) + vocab_size = hf.vocab_size # gpt2_pretok pretokenizers = hf.tokenizer.get("pre_tokenizer", {}).get("pretokenizers", [{}]) pre_tokenizer = hf.tokenizer.get("pre_tokenizer", None) @@ -866,8 +945,8 @@ def check_bpe_tokenizer(hf, vocabs, directory_path): src_vocab = pyonmttok.build_vocab_from_tokens(vocab) # TODO: not sure for which model(s) this is needed for token_name in ["bos_token", "unk_token", "eos_token", "pad_token"]: - if f"{token_name}_id" in hf.config.keys(): - token = hf.config[f"{token_name}_id"] + if f"{token_name}_id" in config.keys(): + token = config[f"{token_name}_id"] if isinstance(token, list): vocabs["specials"][token_name] = vocab[token[0]] elif isinstance(token, str): @@ -1009,8 +1088,8 @@ def run(cls, args): src_vocab=None, tgt_vocab=None, share_vocab=True, - src_vocab_size=hf.config["vocab_size"], - tgt_vocab_size=hf.config["vocab_size"], + src_vocab_size=hf.vocab_size, + tgt_vocab_size=hf.vocab_size, vocab_size_multiple=8, decoder_start_token=vocabs["decoder_start_token"], **vocabs["specials"], diff --git a/eole/config/data.py b/eole/config/data.py index ecfc35a34..d07ac6d0c 100644 --- a/eole/config/data.py +++ b/eole/config/data.py @@ -232,13 +232,13 @@ def _validate_data(self): logger.info(f"Missing transforms field for {cname} data, " f"set to default: {default_transforms}.") corpus.transforms = default_transforms # Check path - if corpus.path_src is None: + if corpus.path_src is None and corpus.path_txt is None: raise ValueError( - f"Corpus {cname} src path is required." + f"Corpus {cname} `path_src` or `path_tgt` is required." "tgt path is also required for non language" " modeling tasks." 
) - else: + elif corpus.path_src is not None: self.__class__._validate_file(corpus.path_src, info=f"{cname}/path_src") if corpus.path_tgt is None: logger.debug("path_tgt is None, it should be set unless the task" " is language modeling") diff --git a/eole/config/models.py b/eole/config/models.py index e60275ab0..cb2d240bd 100644 --- a/eole/config/models.py +++ b/eole/config/models.py @@ -333,6 +333,19 @@ def _validate_transformer_decoder_config(self): return self +class VisionEncoderConfig(TransformerConfig, EncoderConfig): + """ + Based on mistral-community/pixtral-12b, might evolve later. + """ + + encoder_type: Literal["vision"] = Field(default="vision") + # default to Pixtral 12B settings, might change later + num_channels: int | None = 3 + image_size: int | None = 1024 + patch_size: int | None = 16 + image_token_id: int | None = 10 + + # use Field with default= + description would be more readable # was inheriting from VocabConfig, but removed for now to facilitate inference tests # could we have different BaseModelConfig classes (inheriting from a base one) @@ -349,6 +362,7 @@ class BaseModelConfig(Config): RnnEncoderConfig, CnnEncoderConfig, MeanEncoderConfig, + VisionEncoderConfig, ] | None ) = Field( @@ -478,7 +492,7 @@ def update_model_opts(self): if getattr(self.encoder, "encoder_type", None) == "brnn" and self.decoder.decoder_type == "rnn": update_dict["decoder"] = {"bidirectional_encoder": True} - if self.encoder is not None: + if self.encoder is not None and hasattr(self.encoder, "src_word_vec_size"): update_dict["encoder"] = {"src_word_vec_size": self.embeddings.src_word_vec_size} if getattr(self.encoder, "encoder_type", None) == "transformer": update_dict["encoder"].update( @@ -520,11 +534,11 @@ def _override_values(self): encoder_fields = {} decoder_fields = {} for field in self.model_fields_set: - if hasattr(self.embeddings, field): + if hasattr(self.embeddings, field) and field not in self.embeddings.model_fields_set: embeddings_fields[field] = getattr(self, field) - if hasattr(self.encoder, field): + if hasattr(self.encoder, field) and field not in self.encoder.model_fields_set: encoder_fields[field] = getattr(self, field) - if hasattr(self.decoder, field): + if hasattr(self.decoder, field) and field not in self.decoder.model_fields_set: decoder_fields[field] = getattr(self, field) if self.embeddings is not None: self.embeddings.update(**embeddings_fields) @@ -540,8 +554,11 @@ def _validate_model_config(self): # encoder and decoder should be same sizes if self.encoder is not None and self.decoder is not None: - same_size = self.encoder.hidden_size == self.decoder.hidden_size - assert same_size, "The encoder and decoder rnns must be the same size for now" + if isinstance(self.encoder, VisionEncoderConfig): + pass + else: + same_size = self.encoder.hidden_size == self.decoder.hidden_size + assert same_size, "The encoder and decoder must have the same hidden size for seq2seq models." 
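+            # Note: when the encoder is a VisionEncoderConfig the size check is skipped on purpose:
+            # the vision-language adapter projects encoder.hidden_size to decoder.hidden_size,
+            # so the two sizes may legitimately differ.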
if self.share_embeddings: if self.encoder is None or self.decoder is None: @@ -691,6 +708,48 @@ def _validate_transformer(self): return self +class VisionTransformerLMModelConfig(TransformerConfig, BaseModelConfig): + architecture: Literal["vision_transformer_lm"] = Field(default="vision_transformer_lm") + + @model_validator(mode="before") + @classmethod + def encoder_decoder_type(cls, data: Any) -> Any: + # patch to allow transparent setting of encoder/decoder_type + if not (isinstance(data, dict)): + return data + if "encoder" in data.keys(): + data["encoder"]["encoder_type"] = "vision" + else: + data["encoder"] = {"encoder_type": "vision"} + if "decoder" in data.keys(): + data["decoder"]["decoder_type"] = "transformer" + else: + data["decoder"] = {"decoder_type": "transformer"} + return data + + @model_validator(mode="before") + @classmethod + def default_architecture(cls, data: Any) -> Any: + if not (isinstance(data, dict)): + return data + if "architecture" not in data.keys(): + data["architecture"] = "vision_transformer_lm" + return data + + @model_validator(mode="after") + def _validate_vision_transformer(self): + assert not (self.add_estimator), "Estimator layer not supported in Vision Transformer" + return self + + @property + def image_size(self): + return self.encoder.image_size + + @property + def patch_size(self): + return self.encoder.patch_size + + class TransformerEncoderModelConfig(TransformerConfig, BaseModelConfig): """ Facilitate setting some transformer specific params at model level. @@ -742,6 +801,7 @@ def _validate_transformer(self): Union[ TransformerModelConfig, TransformerLMModelConfig, + VisionTransformerLMModelConfig, TransformerEncoderModelConfig, RnnModelConfig, CnnModelConfig, diff --git a/eole/config/run.py b/eole/config/run.py index bce7a86ac..fb5f2ca0a 100644 --- a/eole/config/run.py +++ b/eole/config/run.py @@ -38,6 +38,10 @@ class TrainConfig(LoggingConfig, MiscConfig, DataConfig, VocabConfig): # ModelC def get_model_path(self): return self.training.get_model_path() + @property + def data_type(self): + return self.training.data_type + @classmethod def get_defaults(cls, architecture): return cls( @@ -124,6 +128,8 @@ def _validate_predict_config(self): # TODO: do we really need this _all_transform? if self._all_transform is None: self._all_transform = self.transforms + if getattr(getattr(self.model, "encoder", None), "encoder_type", None) == "vision": + assert self.batch_size == 1, "Batch inference is not supported yet for vision models." if torch.cuda.is_available() and not self.gpu_ranks: logger.warn("You have a CUDA device, should run with -gpu_ranks") return self diff --git a/eole/encoders/__init__.py b/eole/encoders/__init__.py index 8e36b267e..d070ef5ee 100644 --- a/eole/encoders/__init__.py +++ b/eole/encoders/__init__.py @@ -4,6 +4,7 @@ from eole.encoders.rnn_encoder import RNNEncoder from eole.encoders.cnn_encoder import CNNEncoder from eole.encoders.mean_encoder import MeanEncoder +from eole.encoders.vision import VisionEncoder str2enc = { @@ -12,4 +13,5 @@ "cnn": CNNEncoder, "transformer": TransformerEncoder, "mean": MeanEncoder, + "vision": VisionEncoder, } diff --git a/eole/encoders/vision.py b/eole/encoders/vision.py new file mode 100644 index 000000000..2e90cdfbf --- /dev/null +++ b/eole/encoders/vision.py @@ -0,0 +1,148 @@ +""" +Base class for vision encoders. +We'll start with Pixtral encoder, but we can build up on it later. 
+https://github.com/mistralai/mistral-inference/pull/217/files#diff-5cf8b83293298d0bdda514a77cae34565cb54571cddab02801d99a52f0f71a66 +""" + +import torch +import torch.nn as nn +from typing import Optional + +# from eole.modules.multi_headed_attn import SelfMHA +from eole.modules.rmsnorm import RMSNorm +from eole.encoders.transformer import TransformerEncoderLayer +from eole.constants import PositionEncodingType +from eole.modules.rope import RotaryPosition + + +def position_ids_in_meshgrid(patch_embeds_list, max_width): + positions = [] + for patch in patch_embeds_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + + +def create_block_diagonal_mask(lengths, device): + """ + Create a block diagonal mask based on sequence lengths. + Args: + lengths (list of int): List of lengths for each block. + Returns: + torch.Tensor: A 2D boolean tensor with True in the + block diagonal regions and False elsewhere. + """ + total_length = sum(lengths) + mask = torch.zeros((total_length, total_length), dtype=torch.bool) + + start = 0 + for length in lengths: + end = start + length + mask[start:end, start:end] = True + start = end + + return mask.to(device) + + +class VisionEncoder(nn.Module): + def __init__(self, model_config, running_config=None): + super(VisionEncoder, self).__init__() + self.model_config = model_config + if model_config.position_encoding_type == PositionEncodingType.Rotary: + self.rope = RotaryPosition(model_config, mode="2d") + self.patch_conv = nn.Conv2d( + in_channels=model_config.num_channels, + out_channels=model_config.hidden_size, + kernel_size=model_config.patch_size, + stride=model_config.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(model_config.hidden_size, eps=1e-5) + self.transformer_layers = torch.nn.ModuleList() + for _ in range(model_config.layers): + self.transformer_layers.append(TransformerEncoderLayer(model_config, running_config=running_config)) + + head_dim = model_config.hidden_size // model_config.heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + self._freqs_cis: Optional[torch.Tensor] = None + + @classmethod + def from_config(cls, model_config, running_config=None): + """Alternate constructor.""" + return cls( + model_config, + running_config, + ) + + @property + def max_patches_per_side(self): + return self.model_config.image_size // self.model_config.patch_size + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, images): + """ + Args: + images: list of N_img images of variable sizes, each of shape (C, H, W) + Returns: + image_features: tensor of token features for all tokens of all images of + shape (N_toks, D) + """ + + # TODO add as @property somewhere + dtype = next(self.parameters()).dtype + + # pass images through initial convolution independently + patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(dtype)).squeeze(0) for img in images] + + # flatten to a single sequence + patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0) + patch_embeds = self.ln_pre(patch_embeds) + # should probably be handled upstream to have proper batch dim + patch_embeds = patch_embeds.unsqueeze(0) + + # positional embeddings + positions = position_ids_in_meshgrid( + patch_embeds_list, + max_width=self.model_config.image_size // 
self.model_config.patch_size, + ) + if hasattr(self, "rope"): + position_embeddings = self.rope.update( + patch_embeds.size(1), + step=0, + reset=True, + positions=positions, + ) + else: + position_embeddings = None + + # pass through Transformer with a block diagonal mask delimiting images + mask = create_block_diagonal_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + device=self.device, + ) + # revert for proper downstream compatibility + mask = ~mask + out = patch_embeds + for i, layer in enumerate(self.transformer_layers): + out = layer(out, pad_mask=mask, position_embeddings=position_embeddings) + + # remove batch dimension of the single sequence + return out + + +# Multi-Modal Projector +class VisionLanguageAdapter(nn.Module): + def __init__(self, in_dim, out_dim): + super(VisionLanguageAdapter, self).__init__() + self.w_in = nn.Linear(in_dim, out_dim, bias=True) + self.gelu = nn.GELU() + self.w_out = nn.Linear(out_dim, out_dim, bias=True) + + def forward(self, x): + return self.w_out(self.gelu(self.w_in(x))) diff --git a/eole/inputters/dynamic_iterator.py b/eole/inputters/dynamic_iterator.py index cb6c42177..9475d1b48 100644 --- a/eole/inputters/dynamic_iterator.py +++ b/eole/inputters/dynamic_iterator.py @@ -318,7 +318,7 @@ def batch_size_fn(nbsents, maxlen): def max_src_tgt(ex): """return the max tokens btw src and tgt in the sequence.""" - if ex["tgt"]: + if ex.get("tgt", None) is not None: return max(len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"])) return len(ex["src"]["src_ids"]) @@ -393,7 +393,10 @@ def __iter__(self): "cid_line_number", "left_pad", ]: - tensor_batch[key] = tensor_batch[key].to(self.device) + if isinstance(tensor_batch[key], list): + tensor_batch[key] = [t.to(self.device) for t in tensor_batch[key]] + else: + tensor_batch[key] = tensor_batch[key].to(self.device) yield (tensor_batch, bucket_idx) diff --git a/eole/inputters/image_utils.py b/eole/inputters/image_utils.py new file mode 100644 index 000000000..b90cff41b --- /dev/null +++ b/eole/inputters/image_utils.py @@ -0,0 +1,67 @@ +import numpy as np +from PIL import Image +from typing import Tuple + +""" +Most of this code is borrowed from: +https://github.com/mistralai/mistral-common/blob/main/src/mistral_common/tokens/tokenizers/multimodal.py +""" + +# This is strictly restricted to pixtral-12b, extending to other models will require some mapping +DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) # RGB +DATASET_STD = (0.26862954, 0.26130258, 0.27577711) # RGB + + +def _convert_to_rgb(image: Image.Image) -> Image.Image: + """ + Convert a PIL image to RGB. + We ensure transparent background becomes white. 
+ """ + if image.mode == "RGB": + return image + if image.mode != "RGBA": + image = image.convert("RGBA") + white_bg: Image.Image = Image.new("RGBA", image.size, "WHITE") + white_bg.paste(image, (0, 0), image) + return white_bg.convert("RGB") + + +def normalize(np_image, mean, std): + # np_image = np_image / 255.0 + assert len(np_image.shape) == 3, f"{np_image.shape=}" + assert np_image.shape[2] == len(mean) == len(std), f"{np_image.shape=}, {mean=}, {std=}" + mean = np.array(mean, dtype=np_image.dtype) + std = np.array(std, dtype=np_image.dtype) + image = (np_image - mean) / std + return image.transpose(2, 0, 1) + + +def transform_image(image: Image.Image, new_size: Tuple[int, int]) -> np.ndarray: + # resize PIL image directly + resized_image = image.resize(new_size, resample=3, reducing_gap=None) # hardcoded kwargs, reproducing HF + np_image = np.array(_convert_to_rgb(resized_image), dtype=np.uint8) + # rescale (dtype shennanigans to match HF processing) + np_image = (np_image.astype(np.float64) * 1 / 255).astype(np.float32) + return normalize(np_image, DATASET_MEAN, DATASET_STD) + + +def image_to_num_tokens(img, image_size=1024, image_patch_size=16): + w, h = img.size + ratio = max(h / image_size, w / image_size) + if ratio > 1: + w = round(w / ratio) + h = round(h / ratio) + width_tokens = (w - 1) // image_patch_size + 1 + height_tokens = (h - 1) // image_patch_size + 1 + return width_tokens, height_tokens + + +def process_image(image_path, image_size=1024, image_patch_size=16): + image = Image.open(image_path) + w, h = image_to_num_tokens(image, image_size=image_size, image_patch_size=image_patch_size) + new_image_size = (w * image_patch_size, h * image_patch_size) + # TODO retrieve from model config / vocab / tokenizer + image_tokens = (["[IMG]"] * w + ["[IMG_BREAK]"]) * h + image_tokens[-1] = "[IMG_END]" + processed_image = transform_image(image, new_image_size) + return {"image": processed_image, "tokens": image_tokens} diff --git a/eole/inputters/text_corpus.py b/eole/inputters/text_corpus.py index dcb114814..5f8803ec3 100644 --- a/eole/inputters/text_corpus.py +++ b/eole/inputters/text_corpus.py @@ -6,8 +6,10 @@ from eole.constants import CorpusName, CorpusTask from eole.transforms import TransformPipe from eole.inputters.text_utils import transform_bucket +from eole.inputters.image_utils import process_image from contextlib import contextmanager import itertools +import json from datasets import load_dataset @@ -93,6 +95,48 @@ def __str__(self): return f"{cls_name}({self.id}, {self.file_path}, {self.file_path}" f"align={None}" +class ImageTextCorpus(object): + """ + Handle vision data looking like this: + ``` + [ + { + "text": "List the top 5 countries in Europe with the highest GDP {image1}", + "images": { + "image1": "./test_data/gdp.png" + } + }, + ... 
+ ] + ``` + """ + + def __init__(self, name, data, is_train=False): + self.id = name + self.data = data + self.is_train = is_train + + def load(self, offset=0, stride=1): + def make_ex(item): + example = {"text": item["text"], "images": item.get("images", {})} + return example + + if isinstance(self.data, list): + for i, item in enumerate(self.data): + if (i // stride) % stride == offset: + yield make_ex(item) + else: + with exfile_open(self.data, mode="rb") as fs: + data = json.load(fs) + for i, item in enumerate(data): + if (i // stride) % stride == offset: + yield make_ex(item) + + def __str__(self): + cls_name = type(self).__name__ + return f"{cls_name}({self.id}, {self.path}" + + class ParallelCorpus(object): """A parallel corpus file pair that can be loaded to iterate.""" @@ -210,34 +254,52 @@ def get_corpora(config, task=CorpusTask.TRAIN, src=None, tgt=None, align=None): corpus_dict.path_sco, corpus_dict.path_align, ) - else: + elif config.data_type == "text": corpora_dict[corpus_id] = BlockwiseCorpus( corpus_id, corpus_dict.path_txt, block_size=8192, # number of characters ) + elif config.data_type == "image": + corpora_dict[corpus_id] = ImageTextCorpus( + corpus_id, + corpus_dict.path_txt, + is_train=True, + ) elif task == CorpusTask.VALID: if CorpusName.VALID in config.data.keys(): if config.data[CorpusName.VALID].path_tgt is None: path_tgt = config.data[CorpusName.VALID].path_src else: path_tgt = config.data[CorpusName.VALID].path_tgt - corpora_dict[CorpusName.VALID] = ParallelCorpus( - CorpusName.VALID, - config.data[CorpusName.VALID].path_src, - path_tgt if tgt is None else None, - None, - config.data[CorpusName.VALID].path_align, - ) + if config.data_type == "text": + corpora_dict[CorpusName.VALID] = ParallelCorpus( + CorpusName.VALID, + config.data[CorpusName.VALID].path_src, + path_tgt if tgt is None else None, + None, + config.data[CorpusName.VALID].path_align, + ) + elif config.data_type == "image": + corpora_dict[CorpusName.VALID] = ImageTextCorpus( + CorpusName.VALID, + config.data[CorpusName.VALID].path_txt, + is_train=True, + ) else: return None else: - corpora_dict[CorpusName.INFER] = ParallelCorpus( - CorpusName.INFER, - src if src else config.src, - tgt if tgt else config.tgt, - align if align else None, - ) + if config.data_type == "text": + corpora_dict[CorpusName.INFER] = ParallelCorpus( + CorpusName.INFER, + src if src else config.src, + tgt if tgt else config.tgt, + align if align else None, + ) + elif config.data_type == "image": + corpora_dict[CorpusName.INFER] = ImageTextCorpus( + CorpusName.INFER, src, is_train=False # maybe homogenize to some better name + ) return corpora_dict @@ -252,7 +314,7 @@ class ParallelCorpusIterator(object): offset (int): iterate corpus with this line offset. 
""" - def __init__(self, corpus, transform, skip_empty_level="warning", stride=1, offset=0): + def __init__(self, corpus, transform, skip_empty_level="warning", stride=1, offset=0, is_train=False): self.cid = corpus.id self.corpus = corpus self.transform = transform @@ -299,6 +361,53 @@ def __iter__(self): yield from corpus +class ImageTextCorpusIterator(object): + def __init__( + self, + corpus, + transform, + skip_empty_level="warning", + stride=1, + offset=0, + is_train=False, + ): + self.cid = corpus.id + self.corpus = corpus + self.transform = transform + if skip_empty_level not in ["silent", "warning", "error"]: + raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}") + self.skip_empty_level = skip_empty_level + self.stride = stride + self.offset = offset + self.is_train = is_train + + def _process(self, stream): + for i, example in enumerate(stream): + # process images + processed_images = {k: process_image(v) for k, v in example.get("images", {}).items()} + text = example["text"] + image_tokens = { + # using a list here would induce unwanted whitespaces between tokens downstream + k: "".join(v["tokens"]) + for k, v in processed_images.items() + } + text = text.format(**image_tokens) + line_number = i * self.stride + self.offset + example = { + "src": text, + "tgt": text if self.is_train else None, + "images": {k: v["image"] for k, v in processed_images.items()}, + "cid": self.cid, + "cid_line_number": line_number, + } + yield example, self.transform, self.cid + + def __iter__(self): + corpus_stream = self.corpus.load(stride=self.stride, offset=self.offset) + corpus = self._process(corpus_stream) + yield from corpus + + def build_corpora_iters(corpora, transforms, corpora_info, skip_empty_level="warning", stride=1, offset=0): """Return `ParallelCorpusIterator` for all corpora defined in opts.""" corpora_iters = dict() @@ -306,12 +415,17 @@ def build_corpora_iters(corpora, transforms, corpora_info, skip_empty_level="war transform_names = corpora_info[c_id].transforms corpus_transform = [transforms[name] for name in transform_names if name in transforms] transform_pipe = TransformPipe.build_from(corpus_transform) - corpus_iter = ParallelCorpusIterator( + if isinstance(corpus, ParallelCorpus): + iterator_class = ParallelCorpusIterator + elif isinstance(corpus, ImageTextCorpus): + iterator_class = ImageTextCorpusIterator + corpus_iter = iterator_class( corpus, transform_pipe, skip_empty_level=skip_empty_level, stride=stride, offset=offset, + is_train=getattr(corpus, "is_train", None), ) corpora_iters[c_id] = corpus_iter return corpora_iters diff --git a/eole/inputters/text_utils.py b/eole/inputters/text_utils.py index b61f73978..d4a8d8221 100644 --- a/eole/inputters/text_utils.py +++ b/eole/inputters/text_utils.py @@ -9,17 +9,22 @@ def text_sort_key(ex): """Sort using the number of tokens in the sequence.""" - if ex["tgt"]: + if ex.get("tgt", None) is not None: return len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"]) return len(ex["src"]["src_ids"]) def clean_example(maybe_example): - maybe_example["src"] = {"src": " ".join(maybe_example["src"])} - if maybe_example["tgt"] is not None: + if isinstance(maybe_example["src"], list): + maybe_example["src"] = {"src": " ".join(maybe_example["src"])} + else: + maybe_example["src"] = {"src": maybe_example["src"]} + if maybe_example.get("tgt", None) is not None: maybe_example["tgt"] = {"tgt": " ".join(maybe_example["tgt"])} if "align" in maybe_example: maybe_example["align"] = " ".join(maybe_example["align"]) + if "sco" not 
in maybe_example: + maybe_example["sco"] = 1 return maybe_example @@ -39,7 +44,7 @@ def transform_bucket(task, bucket, threshold=0): transf_bucket = transform.batch_apply(sub_bucket, is_train=(task == CorpusTask.TRAIN), corpus_name=cid) for example, transform, cid in transf_bucket: example = clean_example(example) - if len(example["src"]["src"]) > 0 and example["sco"] > threshold: + if len(example["src"]["src"]) > 0 and example.get("sco", 1) > threshold: transformed_bucket.append(example) # at this point an example looks like: @@ -67,7 +72,7 @@ def numericalize(vocabs, example, model_type=ModelType.ENCODER_DECODER): src_text = example["src"]["src"].split(" ") if numeric["src"]["src_ids"] == []: numeric["src"]["src_ids"] = vocabs["src"](src_text) - if example["tgt"] is not None: + if example.get("tgt", None) is not None: if maybe_tgt_ids != []: numeric["tgt"]["tgt_ids"] = maybe_tgt_ids else: @@ -187,7 +192,7 @@ def tensorify(vocabs, minibatch, device, left_pad=False): device=device, ) - if minibatch[0][0]["tgt"] is not None: + if minibatch[0][0].get("tgt", None) is not None: if left_pad: tbatchtgt = [ torch.tensor(ex["tgt"]["tgt_ids"], dtype=torch.long, device=device).flip(dims=[0]) @@ -244,6 +249,19 @@ def tensorify(vocabs, minibatch, device, left_pad=False): alignment[i, : len(ex["alignment"])] = torch.tensor(ex["alignment"], dtype=torch.long, device=device) tensor_batch["alignment"] = alignment + if "images" in minibatch[0][0].keys(): + tensor_batch["images"] = [ + torch.tensor(v, device=device, dtype=torch.float32) + for ex, indice in minibatch + for k, v in ex["images"].items() + # BATCH > 1 not supported yet + # [ + # torch.tensor(v, device=device, dtype=torch.float32) + # for k, v in ex["images"].items() + # ] + # for ex, indice in minibatch + ] + tensor_batch["ind_in_bucket"] = [indice for ex, indice in minibatch] tensor_batch["cid"] = [ex["cid"] for ex, indice in minibatch] diff --git a/eole/models/model.py b/eole/models/model.py index 41563a521..6ae00e942 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -26,6 +26,8 @@ from eole.models.model_saver import load_checkpoint from eole.modules.estimator import FeedForward +from eole.encoders.vision import VisionLanguageAdapter, VisionEncoder + class NoOpPosition: """A no-op position encoding callable.""" @@ -120,6 +122,7 @@ class BaseModel(nn.Module): def __init__(self, **kwargs): super(BaseModel, self).__init__() self.encoder = kwargs.get("encoder", None) + self.adapter = kwargs.get("adapter", None) self.decoder = kwargs.get("decoder", None) self.src_emb = kwargs.get("src_emb", None) self.tgt_emb = kwargs.get("tgt_emb", None) @@ -128,7 +131,8 @@ def __init__(self, **kwargs): self.rope = kwargs.get("rope", None) self.share_decoder_embeddings = False if self.encoder is not None and self.src_emb is None: - raise ValueError("An Encoder needs source Embeddings") + if not isinstance(self.encoder, VisionEncoder): + raise ValueError("An Encoder needs source Embeddings") if self.decoder is not None and self.tgt_emb is None: raise ValueError("A Decoder needs target Embeddings") if self.encoder is None and self.decoder is None: @@ -896,11 +900,108 @@ def update_dropout(self, dropout, attention_dropout): self.encoder.update_dropout(dropout, attention_dropout) +class VisionEncoderDecoderModel(BaseModel): + """VisionEncoderDecoderModel Class + See :class:`~eole.models.BaseModel` for options.""" + + def __init__(self, **kwargs): + super(VisionEncoderDecoderModel, self).__init__(**kwargs) + self.tgt_shift = 1 + self.image_token_id = 
kwargs.get("image_token_id", None) + if self.encoder is None or self.decoder is None: + raise ValueError("A EncoderDecoderModel requires both an Encoder and a Decoder") + # TODO: make this compatible? + # if self.add_estimator: + # self.estimator = FeedForward(self.hidden_size) + + @classmethod + def build_blocks(cls, model_config, vocabs, running_config=None): + encoder = build_encoder(model_config, running_config=running_config) + adapter = VisionLanguageAdapter(model_config.encoder.hidden_size, model_config.decoder.hidden_size) + tgt_emb = build_tgt_emb( + model_config, + vocabs, + running_config=running_config, + share_embeddings=model_config.share_embeddings, + ) + decoder = build_decoder(model_config, running_config=running_config) + return cls( + encoder=encoder, + decoder=decoder, + adapter=adapter, + tgt_emb=tgt_emb, + add_estimator=model_config.add_estimator, + hidden_size=model_config.decoder.hidden_size, + image_token_id=model_config.encoder.image_token_id, + rope=build_rope(model_config), + ) + # from there, the base blocks exist, and the rest is done in the from_opt from base class + + def embed_vision_language_features(self, src, images): + # TODO: test with batch > 1? + batch_size = src.size(0) + text_locations = src != self.image_token_id + image_locations = src == self.image_token_id + text_features = self.tgt_emb(src[text_locations].view(batch_size, -1)) + if len(images) == 0: + return text_features + encoded_images = self.encoder(images) + image_features = self.adapter(encoded_images) + + seq_len = src.shape[1] + batch, N_txt, D_txt = text_features.shape + _, N_img, D_img = image_features.shape + assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}" + assert seq_len == N_txt + N_img, ( + f"seq_len {seq_len} should be equal to N_txt + N_img " f"{(N_txt, N_img, image_locations.sum().item())}" + ) + + combined_features = torch.empty( + (batch, seq_len, D_txt), + dtype=text_features.dtype, + device=text_features.device, + ) + combined_features[text_locations, :] = text_features + if len(images) > 0: + combined_features[image_locations, :] = image_features + + return combined_features + + def forward(self, src, tgt, src_len, bptt=False, with_align=False, images=[]): + """A DecoderModel forward the src side to the decoder along + with the source lengths vector. 
It is a decoder only LM (cf GPT-2)""" + + if not bptt: + self.decoder.init_state() + emb = self.embed_vision_language_features(src, images) + pad_idx = self.tgt_emb.word_padding_idx + pad_mask = src.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt] + position_embeddings = self.rope.update(emb.size(1), step=None) + dec_out, attns = self.decoder( + emb, + enc_out=None, + src_len=src_len, + with_align=with_align, + tgt_pad_mask=pad_mask, + position_embeddings=position_embeddings, + ) + + return dec_out, attns, None + + def update_dropout(self, dropout, attention_dropout): + self.encoder.update_dropout(dropout, attention_dropout) + self.src_emb.update_dropout(dropout) + self.decoder.update_dropout(dropout, attention_dropout) + self.tgt_emb.update_dropout(dropout) + + def get_model_class(model_config): # might have more cases later if model_config.decoder is None: return EncoderModel elif model_config.encoder is None: return DecoderModel + elif model_config.encoder.encoder_type == "vision": + return VisionEncoderDecoderModel else: return EncoderDecoderModel diff --git a/eole/modules/multi_headed_attn.py b/eole/modules/multi_headed_attn.py index bf6dea100..7fbb5ffb6 100644 --- a/eole/modules/multi_headed_attn.py +++ b/eole/modules/multi_headed_attn.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from math import sqrt from torch import Tensor from typing import Optional, Tuple from torch.utils.checkpoint import checkpoint @@ -21,7 +20,12 @@ def shape(x: Tensor, dim_per_head: int) -> Tensor: """[batchsize x length x modeldim] -> [batchsize x heads x length x dimperhead] """ - x_0, x_1, _ = x.size() + # patch for images + if len(x.size()) == 2: + x_0 = 1 + x_1, _ = x.size() + else: + x_0, x_1, _ = x.size() return x.view(x_0, x_1, -1, dim_per_head).transpose(1, 2) @@ -77,7 +81,7 @@ def __init__(self, model_config, running_config=None, is_decoder: bool = True) - self.dim_per_head = model_config.dim_per_head self.heads = model_config.heads self.heads_kv = model_config.heads_kv if model_config.heads_kv is not None else model_config.heads - self.parallel_gpu = running_config.parallel_gpu + self.parallel_gpu = getattr(running_config, "parallel_gpu", 1) assert ( self.dim_per_head * self.heads_kv @@ -247,9 +251,8 @@ def _compute_attention( ) attn = None else: - query /= sqrt(self.dim_per_head) # batch x num_heads x query_len x key_len - scores = torch.matmul(query, key.transpose(2, 3)) + scores = torch.matmul(query, key.transpose(2, 3)) * self.dim_per_head**-0.5 if self.relative_attention_bias is not None: q_len = key.size(2) if self.layer_cache[0] else query.size(2) @@ -287,13 +290,15 @@ def _compute_attention( scores = scores.float() if attn_mask is not None: + if len(attn_mask.size()) == 2: + attn_mask = attn_mask[None, None, :, :] # not 100% necessary but expand to nb of heads attn_mask = attn_mask.expand(-1, self.heads // self.parallel_gpu, -1, -1) # now mask and scores have the same shape - scores = scores.masked_fill(~attn_mask, -1e18) + scores = scores.masked_fill(~attn_mask, torch.finfo(scores.dtype).min) # 3) Apply attention dropout and compute context vectors. 
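+        # The mask value is torch.finfo(scores.dtype).min (finite in fp16/bf16, unlike -1e18),
+        # and softmax runs in float32 before casting back to the query dtype.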
- attn = F.softmax(scores, dim=-1).to(query.dtype) + attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype) drop_attn = self.dropout(attn) if self.dropout_p > 0 else attn attn_output = torch.matmul(drop_attn, value) diff --git a/eole/modules/rope.py b/eole/modules/rope.py index be747aafc..3fb2b4dae 100644 --- a/eole/modules/rope.py +++ b/eole/modules/rope.py @@ -67,7 +67,7 @@ class RotaryPosition(nn.Module): and to support future enhancements, such as additional scaling types. """ - def __init__(self, model_config): + def __init__(self, model_config, mode="1d"): """ Initializes the RotaryPosition module. @@ -91,6 +91,7 @@ def __init__(self, model_config): """ super(RotaryPosition, self).__init__() self.model_config = model_config + self.mode = mode self.dim_per_head = model_config.dim_per_head if model_config.rope_config.rotary_dim == 0: rotary_dim = self.dim_per_head @@ -98,11 +99,35 @@ def __init__(self, model_config): rotary_dim = model_config.rope_config.rotary_dim self.rotary_interleave = model_config.rope_config.rotary_interleave self.rotary_theta = model_config.rope_config.rotary_theta - self.inv_freq = 1.0 / (self.rotary_theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) + inv_freq = 1.0 / (self.rotary_theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)) + if mode == "2d": + inv_freq = self.init_2d_inv_freq(inv_freq) + self.inv_freq = inv_freq # TODO: extend with other scaling types if getattr(self.model_config.rope_config, "scaling_type", None) == "llama3": self.llama3_scaling() - self.update(1024) + cos, sin = self.update(1024) + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) + + def init_2d_inv_freq(self, inv_freq): + """ + Initialize 2d inverse frequencies for pixtral rotary embeddings. + """ + max_patches_per_side = self.model_config.image_size // self.model_config.patch_size + h = torch.arange(max_patches_per_side, device=inv_freq.device) + w = torch.arange(max_patches_per_side, device=inv_freq.device) + freqs_h = torch.outer(h, inv_freq[::2]).float() + freqs_w = torch.outer(w, inv_freq[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, self.dim_per_head // 2) + inv_freq = torch.cat((inv_freq, inv_freq), dim=-1) + return inv_freq def llama3_scaling(self): """ @@ -134,29 +159,12 @@ def llama3_scaling(self): is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) self.inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - def update(self, maxseqlen, step=0, prefetch=1024): - """ - Computes the rotary position embeddings for a given input. - - Args: - maxseqlen: max seq length of the input embeddings. - step: The current step or position within the sequence. Defaults to 0. - offset: An optional offset to apply to the position indices. - This is used for the specific `flash_attn_with_kvcache` path, - which requires processes by chunks of 32 tokens. Defaults to 0. - - Returns: - torch.Tensor: A tensor containing the computed rotary embeddings. - - Notes: - - The returned tensor contains cosine and sine values representing the - rotary embeddings, concatenated along the last dimension. - - The output tensor's dimensions are `[maxseqlen, dim]`, where `dim` is - twice the size of the original inverse frequency tensor (`inv_freq`). 
- """ + def forward_1d(self, maxseqlen, step=0, prefetch=1024): + if step is None: + step = 0 offset = 32 # make sure we have at least 32 positions for flash_attn_with_kvcache if step == 0: - maxseqlen = 1024 # reset as in init() with self.update(1024) + maxseqlen = max(maxseqlen, 1024) # reset as in init() with self.update(1024) elif hasattr(self, "cos") and self.cos.size(0) >= max(offset + (step or 0), 0) + maxseqlen: return self.cos, self.sin @@ -173,3 +181,67 @@ def update(self, maxseqlen, step=0, prefetch=1024): self.register_buffer("cos", cos, persistent=False) self.register_buffer("sin", sin, persistent=False) return cos, sin + + # TODO: investigate if this is useful / properly implemented + # def _dynamic_frequency_update(self, position_ids, device): + # """ + # dynamic RoPE layers should recompute `inv_freq` in the following situations: + # 1 - growing beyond the cached sequence length (allow scaling) + # 2 - the current sequence length is in the original scale + # (avoid losing precision with small sequences) + # """ + # seq_len = torch.max(position_ids) + 1 + # if seq_len > self.max_seq_len_cached: # growth + # inv_freq, self.attention_scaling = self.rope_init_fn( + # self.config, device, seq_len=seq_len, **self.rope_kwargs + # ) + # self.inv_freq = inv_freq + # self.max_seq_len_cached = seq_len + + # if (seq_len < self.original_max_seq_len + # and self.max_seq_len_cached > self.original_max_seq_len): # reset + # self.inv_freq = self.original_inv_freq + # self.max_seq_len_cached = self.original_max_seq_len + + def forward_2d(self, maxseqlen, step=0, prefetch=1024, positions=None): + # TODO: maybe do scaling here + device = self.cos.device if hasattr(self, "cos") else torch.device("cpu") + if step is None: + step = 0 + if positions is None: + tmax = torch.arange(maxseqlen, device=self.inv_freq.device) + else: + tmax = positions + rope = self.inv_freq[tmax].to(device) + # rope is now matrix [maxseqlen, dim/2] + # if device is not None: + # rope = rope.to(device) + cos = rope.cos() + sin = rope.sin() + return cos, sin + + def update(self, maxseqlen, step=0, prefetch=1024, reset=False, positions=None): + """ + Computes the rotary position embeddings for a given input. + Args: + maxseqlen: max seq length of the input embeddings. + step: The current step or position within the sequence. Defaults to 0. + offset: An optional offset to apply to the position indices. + This is used for the specific `flash_attn_with_kvcache` path, + which requires processes by chunks of 32 tokens. Defaults to 0. + Returns: + torch.Tensor: A tensor containing the computed rotary embeddings. + Notes: + - The returned tensor contains cosine and sine values representing the + rotary embeddings, concatenated along the last dimension. + - The output tensor's dimensions are `[maxseqlen, dim]`, where `dim` is + twice the size of the original inverse frequency tensor (`inv_freq`). 
+ """ + if reset: + self.rope = None + if self.mode == "1d": + return self.forward_1d(maxseqlen, step=step, prefetch=prefetch) + elif self.mode == "2d": + return self.forward_2d(maxseqlen, step=step, prefetch=prefetch, positions=positions) + else: + raise NotImplementedError diff --git a/eole/modules/transformer_mlp.py b/eole/modules/transformer_mlp.py index b2289ef24..a49624fdd 100644 --- a/eole/modules/transformer_mlp.py +++ b/eole/modules/transformer_mlp.py @@ -20,7 +20,7 @@ def __init__( model_config, running_config=None, ): - self.parallel_gpu = running_config.parallel_gpu + self.parallel_gpu = getattr(running_config, "parallel_gpu", 1) super(MLP, self).__init__() assert ( model_config.transformer_ff % self.parallel_gpu == 0 diff --git a/eole/predict/__init__.py b/eole/predict/__init__.py index 0009376fb..2ed6c012e 100644 --- a/eole/predict/__init__.py +++ b/eole/predict/__init__.py @@ -15,6 +15,8 @@ def get_infer_class(model_config): return Encoder elif model_config.encoder is None: return GeneratorLM + elif model_config.encoder.encoder_type == "vision": + return GeneratorLM else: return Translator diff --git a/eole/predict/generator.py b/eole/predict/generator.py index d76cf8e19..0f36bffbf 100644 --- a/eole/predict/generator.py +++ b/eole/predict/generator.py @@ -135,6 +135,7 @@ def _predict_batch_with_strategy(self, batch, decode_strategy): src_len=decode_strategy.src_len, step=step if step == 0 else step + max(src_len.tolist()), left_pad=batch["left_pad"], + images=batch.get("images", None), ) if step == 0: @@ -197,7 +198,6 @@ def _score_target(self, batch, enc_out, src_len): log_probs, attn = self._decode_and_generate( src, None, - batch, src_len=src_len, ) diff --git a/eole/predict/inference.py b/eole/predict/inference.py index b48db1a7a..eb137370c 100644 --- a/eole/predict/inference.py +++ b/eole/predict/inference.py @@ -607,6 +607,7 @@ def _decode_and_generate( step=None, return_attn=False, left_pad=False, + images=None, ): # Decoder forward, takes [batch, tgt_len, nfeats] as input @@ -622,10 +623,16 @@ def _decode_and_generate( src_pad_mask = sequence_mask(src_len, src_max_len).unsqueeze(1) # [B, 1, T_src] else: src_pad_mask = None + + if images is not None and step == 0: + emb = self.model.embed_vision_language_features(decoder_in, images) + else: + emb = self.model.tgt_emb(decoder_in, step=step) + tgt_pad_mask = decoder_in.eq(self._tgt_pad_idx).unsqueeze(1) # [B, 1, T_tgt] position_embeddings = self.model.rope.update(decoder_in.size(1), step=step) dec_out, dec_attn = self.model.decoder( - self.model.tgt_emb(decoder_in, step=step), + emb, enc_out=enc_out, src_len=src_len, step=step, diff --git a/eole/trainer.py b/eole/trainer.py index 03bc7a729..6c8196957 100644 --- a/eole/trainer.py +++ b/eole/trainer.py @@ -459,10 +459,13 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats, repor report_stats.n_src_words += src_len.sum().item() total_stats.n_src_words += src_len.sum().item() tgt = batch["tgt"] + kwargs = {} + if "images" in batch.keys(): + kwargs["images"] = batch["images"] try: with get_autocast(enabled=self.optim.amp): - model_out, attns, estim = self.model(src, tgt, src_len, with_align=self.with_align) + model_out, attns, estim = self.model(src, tgt, src_len, with_align=self.with_align, **kwargs) if self.zero_out_prompt_loss: # The loss of the prompt will be set to zero. 
batch = self.train_loss.ignore_prompt(batch) diff --git a/eole/transforms/tokenize_id.py b/eole/transforms/tokenize_id.py index 27fa5c8ca..40cd32342 100644 --- a/eole/transforms/tokenize_id.py +++ b/eole/transforms/tokenize_id.py @@ -100,10 +100,18 @@ def tokenize_string(self, string, side="src", is_train=False): return tokens def apply(self, example, is_train=False, stats=None, **kwargs): - src_tokens = self.tokenize_string(" ".join(example["src"]), side="src", is_train=is_train) + if isinstance(example["src"], str): + src_tokens = self.tokenize_string(example["src"], side="src", is_train=is_train) + elif isinstance(example["src"], list): + src_tokens = self.tokenize_string(" ".join(example["src"]), side="src", is_train=is_train) + else: + raise ValueError(f"Unsupported src type: {type(example['src'])}") example["src_ids"] = src_tokens if example.get("tgt", None) is not None: - tgt_tokens = self.tokenize_string(" ".join(example["tgt"]), side="tgt", is_train=is_train) + if isinstance(example["tgt"], str): + tgt_tokens = self.tokenize_string(example["tgt"], side="tgt", is_train=is_train) + elif isinstance(example["tgt"], list): + tgt_tokens = self.tokenize_string(" ".join(example["tgt"]), side="tgt", is_train=is_train) example["tgt_ids"] = tgt_tokens return example diff --git a/recipes/pixtral/README.md b/recipes/pixtral/README.md new file mode 100644 index 000000000..ec338c37b --- /dev/null +++ b/recipes/pixtral/README.md @@ -0,0 +1,21 @@ +# Pixtral + +Only the mistral-community/pixtral-12b version is supported right now. The official mistralai/Pixtral-12B-2409 config is rather incomplete. + +## Convert the model + +``` +eole convert HF --model_dir mistral-community/pixtral-12b --output ./pixtral-12b --token $HF_TOKEN +``` + +## Run the test script + +``` +python3 test_inference.py +``` + +There are several examples in the test script (taken from pixtral blog posts). A single one is activated by default, but you can uncomment the others to test the various cases. + +## Finetuning + +Finetuning is untested for now. Feel free to try it out and fix any arising issues. 
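+
+## Input format
+
+Each example is a dict with a `text` prompt and an `images` mapping. `{image1}`-style placeholders in the text are replaced by the corresponding image tokens during preprocessing. A minimal entry, as used in `test_inference.py`:
+
+```
+{
+    "text": "[INST]List the top 5 countries in Europe with the highest GDP\n{image1}[/INST]",
+    "images": {"image1": "./test_data/gdp.png"}
+}
+```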
\ No newline at end of file diff --git a/recipes/pixtral/finetune.yaml b/recipes/pixtral/finetune.yaml new file mode 100644 index 000000000..0b116e739 --- /dev/null +++ b/recipes/pixtral/finetune.yaml @@ -0,0 +1,89 @@ +# General settings +seed: 1234 +share_vocab: true +save_data: "./finetune/pixtral-finetune" +src_vocab: "./pixtral-12b/vocab.txt" +src_vocab_size: 132000 +tgt_vocab_size: 132000 + +overwrite: true + +report_every: 10 + +# datasets +data: + test_data: + path_txt: "./train_data.json" + + # valid: + # path_src: "./data/valid.txt" + +skip_empty_level: silent + +transforms_configs: + huggingface_tokenize: + max_length: 4096 + +training: + data_type: "image" + # GPU dispatching + world_size: 1 + gpu_ranks: [0] + # 2 GPU + # world_size: 2 + # gpu_ranks: [0, 1] + # parallel_mode: tensor_parallel + dropout_steps: [0] + dropout: [0.0] + attention_dropout: [0.0] + # Batching + # bucket_size: 32768 + bucket_size: 10 + num_workers: 0 + batch_type: "sents" + batch_size: 1 + valid_batch_size: 1 + batch_size_multiple: 1 + + # Optimization + # compute_dtype: "fp16" + # apex_opt_level: "" + # optim: "fusedadam" + compute_dtype: "bf16" + optim: "adam" + learning_rate: 0.0001 + warmup_steps: 100 + decay_method: "none" + #learning_rate_decay: 0.98 + #start_decay_steps: 100 + #decay_steps: 10 + adam_beta2: 0.998 + accum_count: [8] + accum_steps: [0] + max_grad_norm: 0 + label_smoothing: 0.0 + param_init_method: xavier_uniform + normalization: "tokens" + + # folders + train_from: "./pixtral-12b" + model_path: "./finetune/pixtral-finetuned" + keep_checkpoint: 10 + save_checkpoint_steps: 100 + + train_steps: 1000 + valid_steps: 100 + + # 4/8bit + quant_layers: ['gate_up_proj', 'down_proj', 'up_proj', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'] + quant_type: "bnb_NF4" + + # LoRa + lora_layers: ['linear_values', 'linear_query', 'linear_keys', 'final_linear'] + lora_rank: 2 + lora_dropout: 0.05 + lora_alpha: 8 + lora_embedding: false + + # Chekpointing + #use_ckpting: ['ffn', 'lora'] diff --git a/recipes/pixtral/test_data/gdp.png b/recipes/pixtral/test_data/gdp.png new file mode 100644 index 000000000..e70ff6691 Binary files /dev/null and b/recipes/pixtral/test_data/gdp.png differ diff --git a/recipes/pixtral/test_data/image-to-code.jpg b/recipes/pixtral/test_data/image-to-code.jpg new file mode 100644 index 000000000..c457f36df Binary files /dev/null and b/recipes/pixtral/test_data/image-to-code.jpg differ diff --git a/recipes/pixtral/test_data/image1.png b/recipes/pixtral/test_data/image1.png new file mode 100644 index 000000000..81c5e469b Binary files /dev/null and b/recipes/pixtral/test_data/image1.png differ diff --git a/recipes/pixtral/test_data/image2.png b/recipes/pixtral/test_data/image2.png new file mode 100644 index 000000000..d94810ba1 Binary files /dev/null and b/recipes/pixtral/test_data/image2.png differ diff --git a/recipes/pixtral/test_data/image3.png b/recipes/pixtral/test_data/image3.png new file mode 100644 index 000000000..f125be49c Binary files /dev/null and b/recipes/pixtral/test_data/image3.png differ diff --git a/recipes/pixtral/test_data/image4.png b/recipes/pixtral/test_data/image4.png new file mode 100644 index 000000000..bc5507d8a Binary files /dev/null and b/recipes/pixtral/test_data/image4.png differ diff --git a/recipes/pixtral/test_data/loss_curve.jpg b/recipes/pixtral/test_data/loss_curve.jpg new file mode 100644 index 000000000..1d204f6e1 Binary files /dev/null and b/recipes/pixtral/test_data/loss_curve.jpg differ diff --git 
a/recipes/pixtral/test_data/multi-images.png b/recipes/pixtral/test_data/multi-images.png new file mode 100644 index 000000000..31656bea1 Binary files /dev/null and b/recipes/pixtral/test_data/multi-images.png differ diff --git a/recipes/pixtral/test_data/pisa_2.jpg b/recipes/pixtral/test_data/pisa_2.jpg new file mode 100644 index 000000000..786a47ab5 Binary files /dev/null and b/recipes/pixtral/test_data/pisa_2.jpg differ diff --git a/recipes/pixtral/test_data/table1.png b/recipes/pixtral/test_data/table1.png new file mode 100644 index 000000000..5ce327486 Binary files /dev/null and b/recipes/pixtral/test_data/table1.png differ diff --git a/recipes/pixtral/test_data/table2.png b/recipes/pixtral/test_data/table2.png new file mode 100644 index 000000000..9f53fa3cb Binary files /dev/null and b/recipes/pixtral/test_data/table2.png differ diff --git a/recipes/pixtral/test_inference.py b/recipes/pixtral/test_inference.py new file mode 100644 index 000000000..048f59c38 --- /dev/null +++ b/recipes/pixtral/test_inference.py @@ -0,0 +1,93 @@ +# flake8: noqa + +from rich import print +from eole.config.run import * +from eole.inference_engine import InferenceEnginePY + +config = PredictConfig( + model_path="./pixtral-12b", + src="dummy", + max_length=500, + gpu_ranks=[0], + # quant_type="bnb_NF4", + quant_type="bnb_FP4", # HF default, using it for initial reproducibility checks + quant_layers=[ + "gate_up_proj", + "down_proj", + "up_proj", + "linear_values", + "linear_query", + "linear_keys", + "final_linear", + "w_in", + "w_out", + ], + compute_dtype="fp16", + top_p=0.8, + temperature=0.35, + beam_size=1, + seed=42, + batch_size=1, + batch_type="sents", +) + +print(config) + +config.data_type = "image" +engine = InferenceEnginePY(config) + +print(engine.predictor.model) +engine.predictor.model.count_parameters() + +test_input = [ + { + "text": "[INST]List the top 5 countries in Europe with the highest GDP\n{image1}[/INST]", + "images": {"image1": "./test_data/gdp.png"}, + }, + # { + # "text": "[INST]When did things start to go wrong for dark dragon?\n{image1}[/INST]", + # "images": { + # "image1": "./test_data/loss_curve.jpg" + # } + # }, + # { + # "text": "[INST]Is this person really big, or is this building just super small?\n{image1}[/INST]", + # "images": { + # "image1": "./test_data/pisa_2.jpg" + # } + # }, + # { + # "text": "[INST]Combine information in both the tables into a single markdown table\n{image1}\n{image2}[/INST]", + # "images": { + # "image1": "./test_data/table1.png", + # "image2": "./test_data/table2.png" + # } + # }, + # { + # "text": "[INST]Combine information in both the tables into a single markdown table\n{image1}[/INST]", + # "images": { + # "image1": "./test_data/multi-images.png" + # } + # }, + # { + # "text": "[INST]Describe the images.\n{image1}\n{image2}\n{image3}\n{image4}[/INST]", + # "images": { + # "image1": "./test_data/image1.png", + # "image2": "./test_data/image2.png", + # "image3": "./test_data/image3.png", + # "image4": "./test_data/image4.png", + # } + # }, + # { + # "text": "[INST]Combine information in both the tables into a single markdown table\n{image1}{image2}[/INST]", + # "images": { + # "image1": "./test_data/table1.png", + # "image2": "./test_data/table2.png" + # } + # }, +] + +pred = engine.infer_list(test_input) + +print(pred) +print(pred[2][0][0].replace("⦅newline⦆", "\n")) diff --git a/recipes/pixtral/train_data.json b/recipes/pixtral/train_data.json new file mode 100644 index 000000000..909650280 --- /dev/null +++ 
b/recipes/pixtral/train_data.json @@ -0,0 +1,20 @@ +[ + { + "text": "List the top 5 countries in Europe with the highest GDP\n{image1}\nGermany\nUnited Kingdom\nFrance\nItaly\nSpain", + "images": { + "image1": "./test_data/gdp.png" + } + }, + { + "text": "When did things start to go wrong for dark dragon?\n{image1}\nAround step 10k, then worsened around step 20k.", + "images": { + "image1": "./test_data/loss_curve.jpg" + } + }, + { + "text": "Is this person really big, or is this building just super small?\n{image1}\nNone, it's a trick of perspective.", + "images": { + "image1": "./test_data/pisa_2.jpg" + } + } +] \ No newline at end of file
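A minimal sketch of how an entry like the ones above is preprocessed, mirroring `ImageTextCorpusIterator._process`; the path reuses `./test_data/gdp.png` from this recipe, and the `process_image` defaults (`image_size=1024`, `image_patch_size=16`) match the values hard-coded for pixtral-12b:

```
from eole.inputters.image_utils import process_image

entry = {
    "text": "List the top 5 countries in Europe with the highest GDP\n{image1}",
    "images": {"image1": "./test_data/gdp.png"},
}

# Resize to multiples of the patch size, normalize, and build the image token sequence.
processed = {k: process_image(v) for k, v in entry["images"].items()}
# processed["image1"]["image"] is a (3, H, W) float array;
# processed["image1"]["tokens"] is ["[IMG]", ..., "[IMG_BREAK]", ..., "[IMG_END]"]:
# one "[IMG]" per patch column, one break per patch row, ending with "[IMG_END]".
text = entry["text"].format(**{k: "".join(v["tokens"]) for k, v in processed.items()})
```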