Website | arXiv | 🤗 Demo | BibTeX
Official implementation and pre-trained models for:
FlexTok: Resampling Images into 1D Token Sequences of Flexible Length, arXiv 2025
Roman Bachmann*, Jesse Allardice*, David Mizrahi*, Enrico Fini, Oğuzhan Fatih Kar, Elmira Amirloo, Alaaeldin El-Nouby, Amir Zamir, Afshin Dehghan
- Clone this repository and navigate to the root directory:

```bash
git clone https://github.com/apple/ml-flextok
cd ml-flextok
```

- Create a new conda environment, then install the package and its dependencies:

```bash
conda create -n flextok python=3.10 -y
source activate flextok
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
```
- Verify that CUDA is available in PyTorch by running the following in a Python shell:

```python
import torch
print(torch.cuda.is_available())  # Should print True
```

If CUDA is not available, consider re-installing PyTorch following the official installation instructions.
- (Optional) Expose the new conda environment as a kernel to Jupyter notebooks:

```bash
pip install ipykernel
python -m ipykernel install --user --name flextok --display-name "FlexTok (flextok)"
```
We recommend checking out the Jupyter notebook in `notebooks/flextok_inference.ipynb` to get started with the FlexTok tokenizer and VAE inference. Please see the Model Zoo below for all available FlexTok and VAE models, as well as sample code snippets that illustrate encoding and decoding.
We provide FlexTok and VAE checkpoints as safetensors, and also offer easy loading via Hugging Face Hub.
| Encoder layers | Decoder layers | Dataset | HF Hub | Safetensors |
|---|---|---|---|---|
| 12 | 12 | IN1K | EPFL-VILAB/flextok_d12_d12_in1k | Checkpoint |
| 18 | 18 | IN1K | EPFL-VILAB/flextok_d18_d18_in1k | Checkpoint |
| 18 | 28 | IN1K | EPFL-VILAB/flextok_d18_d28_in1k | Checkpoint |
| 18 | 28 | DFN | EPFL-VILAB/flextok_d18_d28_dfn | Checkpoint |
Example usage, loading a FlexTok d18-d28 DFN model directly from Hugging Face Hub:

```python
from flextok.flextok_wrapper import FlexTokFromHub

model = FlexTokFromHub.from_pretrained('EPFL-VILAB/flextok_d18_d28_dfn').eval()
```
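Detokenization runs the rectified flow decoder for multiple denoising steps, so a GPU helps considerably. A minimal sketch, assuming the wrapper behaves like a standard `torch.nn.Module`:

```python
import torch

# Move the model to a GPU if one is available; inputs passed to
# tokenize/detokenize should then live on the same device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
```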
Alternatively, download the safetensors checkpoint manually and load it using our helper functions:

```python
from hydra.utils import instantiate

from flextok.utils.checkpoint import load_safetensors

ckpt, config = load_safetensors('/path/to/model.safetensors')
model = instantiate(config).eval()
model.load_state_dict(ckpt)
```
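To fetch the checkpoint programmatically rather than via the browser, a sketch using `huggingface_hub` is shown below. Note that the exact safetensors filename inside each repo is an assumption here; check the repo's file listing on the Hub:

```python
from hydra.utils import instantiate
from huggingface_hub import hf_hub_download

from flextok.utils.checkpoint import load_safetensors

# NOTE: 'model.safetensors' is an assumed filename; verify it on the Hub repo page
ckpt_path = hf_hub_download(
    repo_id='EPFL-VILAB/flextok_d18_d28_dfn',
    filename='model.safetensors',
)
ckpt, config = load_safetensors(ckpt_path)
model = instantiate(config).eval()
model.load_state_dict(ckpt)
```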
After loading a FlexTok model, image batches can be encoded using:

```python
from flextok.utils.demo import imgs_from_urls

# Load example images of shape (B, 3, 256, 256), normalized to [-1,1]
imgs = imgs_from_urls(urls=['https://storage.googleapis.com/flextok_site/nb_demo_images/0.png'])

# tokens_list is a list of [1, 256] discrete token sequences
tokens_list = model.tokenize(imgs)
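```

Since `tokenize` returns one sequence per image in the batch, a quick shape sanity check (a minimal sketch reusing the objects defined above) is:

```python
# tokenize returns one token sequence per image in the batch
print(len(tokens_list))      # B, the number of input images
print(tokens_list[0].shape)  # torch.Size([1, 256])
```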
The list of token sequences can be truncated in a nested fashion:

```python
k_keep = 64  # For example, only keep the first 64 out of 256 tokens
tokens_list = [t[:, :k_keep] for t in tokens_list]
```
To decode the tokens with FlexTok's rectified flow decoder, call:

```python
# tokens_list is a list of [1, l] discrete token sequences, with l <= 256
# reconst is a [B, 3, 256, 256] tensor, normalized to [-1,1]
reconst = model.detokenize(
    tokens_list,
    timesteps=20,  # Number of denoising steps
    guidance_scale=7.5,  # Classifier-free guidance scale
    perform_norm_guidance=True,  # See https://arxiv.org/abs/2410.02416
)
```
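Putting these pieces together, here is a minimal end-to-end sketch that encodes an image once and decodes it at several token budgets. The `save_image` call and the rescaling from [-1,1] to [0,1] are our additions for visualization, not part of the FlexTok API:

```python
from torchvision.utils import save_image

from flextok.utils.demo import imgs_from_urls

imgs = imgs_from_urls(urls=['https://storage.googleapis.com/flextok_site/nb_demo_images/0.png'])
tokens_list = model.tokenize(imgs)

# Decode the same image from coarse (16 tokens) to fine (all 256 tokens)
for k_keep in [16, 64, 256]:
    truncated = [t[:, :k_keep] for t in tokens_list]
    reconst = model.detokenize(
        truncated,
        timesteps=20,
        guidance_scale=7.5,
        perform_norm_guidance=True,
    )
    # Rescale from [-1,1] to [0,1] before saving; detach in case gradients are tracked
    save_image(reconst.detach() * 0.5 + 0.5, f'reconst_k{k_keep}.png')
```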
We also provide the following VAE checkpoints:

| Latent channels | Downsampling factor | HF Hub | Safetensors |
|---|---|---|---|
| 4 | 8 | EPFL-VILAB/flextok_vae_c4 | Checkpoint |
| 8 | 8 | EPFL-VILAB/flextok_vae_c8 | Checkpoint |
| 16 | 8 | EPFL-VILAB/flextok_vae_c16 | Checkpoint |
Example usage, loading an AutoencoderKL directly from Hugging Face Hub and autoencoding a sample image:

```python
from diffusers.models import AutoencoderKL

from flextok.utils.demo import imgs_from_urls

vae = AutoencoderKL.from_pretrained(
    'EPFL-VILAB/flextok_vae_c16', low_cpu_mem_usage=False
).eval()

# Load image of shape (B, 3, H, W), normalized to [-1,1]
imgs = imgs_from_urls(urls=['https://storage.googleapis.com/four_m_site/images/demo_rgb.png'])

# Autoencode with the VAE
latents = vae.encode(imgs).latent_dist.sample()  # Shape (B, D, H//8, W//8), with D in {4, 8, 16}
reconst = vae.decode(latents).sample  # Shape (B, 3, H, W)
```
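To eyeball the reconstruction quality, you can save the input and output side by side. This is a sketch for visualization only; the rescaling from [-1,1] to [0,1] mirrors the normalization noted in the comments above:

```python
import torch
from torchvision.utils import save_image

# Stack input and reconstruction into one grid: inputs on the first row,
# reconstructions on the second; detach in case gradients are tracked
grid = torch.cat([imgs, reconst], dim=0).detach()
save_image(grid * 0.5 + 0.5, 'vae_reconstruction.png', nrow=imgs.shape[0])
```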
The code in this repository is released under the license as found in the LICENSE file.
The model weights in this repository are released under the Apple Machine Learning Research Model license as found in the LICENSE_WEIGHTS file.
If you find this repository helpful, please consider citing our work:
```bibtex
@article{flextok,
  title={{FlexTok}: Resampling Images into 1D Token Sequences of Flexible Length},
  author={Roman Bachmann and Jesse Allardice and David Mizrahi and Enrico Fini and O{\u{g}}uzhan Fatih Kar and Elmira Amirloo and Alaaeldin El-Nouby and Amir Zamir and Afshin Dehghan},
  journal={arXiv 2025},
  year={2025},
}
```