This repository contains the official implementation of MutBERT: Probabilistic Genome Representation Improves Genomics Foundation Models.
- 1. Introduction
- 2. Model and Results
- 3. Setup environment
- 4. Quick Start
- 5. Pre-Training
- 6. Finetune
- 7. Citation
Understanding the genomic foundation of human diversity and disease requires models that effectively capture sequence variation, such as single nucleotide polymorphisms (SNPs). While recent genomic foundation models have scaled to larger datasets and multi-species inputs, they often fail to account for the sparsity and redundancy inherent in human population data, such as those in the 1000 Genomes Project. SNPs are rare in humans, and current masked language models (MLMs) trained directly on whole-genome sequences may struggle to efficiently learn these variations. Additionally, training on the entire dataset without prioritizing regions of genetic variation results in inefficiencies and negligible gains in performance.
We present MutBERT, a probabilistic genome-based masked language model that efficiently utilizes SNP information from population-scale genomic data. By representing the entire genome as a probabilistic distribution over observed allele frequencies, MutBERT focuses on informative genomic variations while maintaining computational efficiency.
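For intuition, here is a toy sketch of the idea (illustrative only; the actual training matrices are produced by the data-preparation steps described below): each genomic position is encoded as a probability distribution over the four bases, derived from observed allele frequencies, rather than as a single hard-coded base.

import numpy as np

# Toy illustration, not the repo's on-disk format.
# Columns correspond to the bases A, C, G, T.
invariant_pos = np.array([1.0, 0.0, 0.0, 0.0])   # reference base "A", no observed variation
snp_pos = np.array([0.88, 0.0, 0.12, 0.0])       # reference "A" with alternate "G" at 12% allele frequency
sequence_matrix = np.stack([invariant_pos, snp_pos])  # shape: (num_positions, 4)
print(sequence_matrix)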
All three pre-trained models are available on Hugging Face: JadenLong/MutBERT, JadenLong/MutBERT-Human-Ref, and JadenLong/MutBERT-Multi. Link to the HuggingFace ModelHub.
# create and activate virtual python environment
conda create -n mutbert python=3.12
conda activate mutbert
# install required packages
pip install -r requirements.txt
DNABERT-2 mainly relies on transformers==4.29.2.
# create and activate virtual python environment
conda create -n dnabert2 python=3.8
conda activate dnabert2
# install required packages
pip install transformers==4.29.2 numpy==1.24.4 torch==2.4.1 accelerate==1.0.1 \
    tqdm==4.67.0 peft==0.13.2 scikit-learn==1.3.2 pandas==2.0.3
Our model is easy to use with the transformers package.
To load the model from Hugging Face:
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

model_name = "JadenLong/MutBERT"
# Optional: JadenLong/MutBERT-Human-Ref, JadenLong/MutBERT-Multi
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
cls_model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True, num_labels=2)
The default attention implementation is flash attention via PyTorch SDPA ("sdpa"). If you want to use basic attention, you can replace it with "eager". Please refer to here.
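For example, one way to select the attention backend at load time, assuming your transformers version supports the attn_implementation argument and the remote code honors it (a sketch, not the only way to configure this):

from transformers import AutoModel

model = AutoModel.from_pretrained(
    "JadenLong/MutBERT",
    trust_remote_code=True,
    attn_implementation="eager",  # default is "sdpa"; "eager" falls back to basic attention
)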
To get the embeddings of a DNA sequence:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
model_name = "JadenLong/MutBERT"
# Optional: JadenLong/MutBERT-Human-Ref, JadenLong/MutBERT-Multi
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
dna = "ATCGGGGCCCATTA"
inputs = tokenizer(dna, return_tensors='pt')["input_ids"]
mut_inputs = F.one_hot(inputs, num_classes=len(tokenizer)).float().to("cpu") # len(tokenizer) is vocab size
last_hidden_state = model(mut_inputs).last_hidden_state # [1, sequence_length, 768]
# or: last_hidden_state = model(mut_inputs)[0] # [1, sequence_length, 768]
# embedding with mean pooling
embedding_mean = torch.mean(last_hidden_state[0], dim=0)
print(embedding_mean.shape) # expect to be 768
# embedding with max pooling
embedding_max = torch.max(last_hidden_state[0], dim=0)[0]
print(embedding_max.shape) # expect to be 768
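Because the model's forward pass consumes a probability matrix over the vocabulary, you can also pass "soft" inputs instead of one-hot vectors, e.g. to encode uncertainty at a SNP position. The snippet below continues from the code above and is illustrative only; it assumes the tokenizer has single-nucleotide tokens named "A" and "G" and that position 0 holds a special token.

# Illustrative: replace the one-hot row at position 1 with a soft distribution.
soft_inputs = F.one_hot(inputs, num_classes=len(tokenizer)).float()
a_id = tokenizer.convert_tokens_to_ids("A")  # assumes an "A" token exists
g_id = tokenizer.convert_tokens_to_ids("G")  # assumes a "G" token exists
soft_inputs[0, 1, :] = 0.0
soft_inputs[0, 1, a_id] = 0.88  # 88% reference allele
soft_inputs[0, 1, g_id] = 0.12  # 12% alternate allele
soft_hidden = model(soft_inputs).last_hidden_state  # same shape as above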
Allowed types for RoPE scaling are linear and dynamic. To extend the model's context window, pass the rope_scaling parameter. For example, to scale the model's context by 2x:
from transformers import AutoModel

model_name = "JadenLong/MutBERT"
# Optional: JadenLong/MutBERT-Human-Ref, JadenLong/MutBERT-Multi
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    rope_scaling={"type": "dynamic", "factor": 2.0},  # 2.0 for 2x scaling, 4.0 for 4x, etc.
)
The raw training data is available:
- Mutation data: Download *.vcf.gz.
- Human reference genome: Download hg38.fa.gz.
After downloading the raw data, we used bcftools to process the VCF files. Link to script.
You can follow these 7 steps to prepare the data (a rough driver sketch follows the list):
- csv_post_process(): add header of csv files
- fa2npy(): extract sequence data from hg38.fa.gz, save as chr_name.npy
- split_by_n(): split sequence data by "N" from chr_name.npy, save as chr_name_part_i.npy
- create_sm_matrix(): map bases to float values and create the smooth matrix from chr_name_part_i.npy (str) and clean.chr_name.csv, save as chr_name_part_i.npy (float)
- cat_all_npy(): concatenate all the interval smooth matrices from chr_name_part_i.npy (float), save as train_data.npy and test_data.npy
- get_range_list(): retrieve the list of available segment ranges.
- get_start_indices(): generate starting indices within those ranges, to be randomly sampled as starting points during pre-training.
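A rough driver for these steps might look like the sketch below. The function names come from the list above; the module name (prepare_data) and the implied file locations are placeholders for the linked script and your local paths.

# Hypothetical driver chaining the data-preparation steps in order.
from prepare_data import (
    csv_post_process, fa2npy, split_by_n, create_sm_matrix,
    cat_all_npy, get_range_list, get_start_indices,
)

csv_post_process()             # 1. add headers to the bcftools-derived CSV files
fa2npy()                       # 2. extract per-chromosome sequences from hg38.fa.gz
split_by_n()                   # 3. split each chromosome at runs of "N"
create_sm_matrix()             # 4. map bases to floats and build the smooth matrices
cat_all_npy()                  # 5. concatenate segments into train_data.npy / test_data.npy
ranges = get_range_list()      # 6. list the usable segment ranges
starts = get_start_indices()   # 7. starting indices to sample from during pre-training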
We used and modified run_mlm_no_trainer.py, available here.
First, open your terminal and run:
accelerate config
Follow the prompts to configure accelerate.
After that, run pretrain.sh.
bash pretrain.sh
We use GUE (proposed by DNABERT-2) to conduct TFBS evaluation.
Please first download the GUE dataset from here. Then run the scripts to evaluate on all the tasks.
P.S.: You should use the dnabert2 environment to finetune DNABERT-2 on TFBS.
We use NT-downstream Tasks to conduct this evaluation.
Run the scripts directly; they will automatically download the datasets and perform finetuning.
P.S.: You should use the dnabert2 environment to finetune DNABERT-2 on the NT downstream tasks.
We used and modified vep_embeddings.py and vep_svm.ipynb from the Caduceus Model.
Run the scripts directly; they will automatically download the datasets and perform finetuning.
P.S.: You should use the dnabert2 environment to finetune DNABERT-2 on the eQTL-VEP tasks.
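For orientation, here is a minimal sketch of the embed-then-SVM idea behind the VEP evaluation (not the repo's scripts: the dataset loading, sequence construction, and feature choice are placeholders):

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.svm import SVC
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("JadenLong/MutBERT")
model = AutoModel.from_pretrained("JadenLong/MutBERT", trust_remote_code=True).eval()

def embed(seq: str) -> np.ndarray:
    # Mean-pooled MutBERT embedding of a DNA sequence (one-hot input, as in Quick Start).
    ids = tokenizer(seq, return_tensors="pt")["input_ids"]
    probs = F.one_hot(ids, num_classes=len(tokenizer)).float()
    with torch.no_grad():
        hidden = model(probs).last_hidden_state  # [1, seq_len, 768]
    return hidden.mean(dim=1).squeeze(0).numpy()

# Placeholder data: ref/alt sequence pairs and labels would come from the eQTL-VEP dataset.
ref_seqs = ["ATCGGGGCCCATTA", "ATCGGGGCCCATTA"]
alt_seqs = ["ATCGGGGCCGATTA", "ATCGAGGCCCATTA"]
labels = [1, 0]

features = np.stack([embed(a) - embed(r) for r, a in zip(ref_seqs, alt_seqs)])
clf = SVC().fit(features, labels)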
If you have any questions regarding our paper or code, please feel free to open an issue or email Weicai Long (wlong381 AT connect dot hkust-gz dot edu dot cn).
If you use MutBERT in your work, please kindly cite our paper:
@article{long2025mutbert,
  title={MutBERT: Probabilistic Genome Representation Improves Genomics Foundation Models},
  author={Long, Weicai and Su, Houcheng and Xiong, Jiaqi and Zhang, Yanlin},
  journal={bioRxiv},
  pages={2025--01},
  year={2025},
  publisher={Cold Spring Harbor Laboratory}
}