CVPR 2025 ✨ Highlight · 📄 Paper
Tommie Kerssies1, Niccolò Cavagnero2,*, Alexander Hermans3, Narges Norouzi1, Giuseppe Averta2, Bastian Leibe3, Gijs Dubbelman1, Daan de Geus1,3
¹ Eindhoven University of Technology
² Polytechnic of Turin
³ RWTH Aachen University
* Work done while visiting RWTH Aachen University
We present the Encoder-only Mask Transformer (EoMT), a minimalist image segmentation model that repurposes a plain Vision Transformer (ViT) to jointly encode image patches and segmentation queries as tokens. No adapters. No decoders. Just the ViT.
Leveraging large-scale pre-trained ViTs, EoMT achieves accuracy similar to state-of-the-art methods that rely on complex, task-specific components. At the same time, it is significantly faster thanks to its simplicity, for example up to 4× faster with ViT-L.
Turns out, your ViT is secretly an image segmentation model. EoMT shows that architectural complexity isn't necessary. For segmentation, a plain Transformer is all you need.
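As a rough intuition for how this works, here is a minimal, self-contained sketch (not the actual EoMT implementation; the module names, block split, and sizes are illustrative): learnable query tokens are appended to the ViT patch tokens, the final blocks process both jointly, and masks come from a dot product between query and patch embeddings.
import torch
import torch.nn as nn

class ToyEoMT(nn.Module):
    """Toy sketch: a plain ViT-style encoder that also carries segmentation query tokens."""
    def __init__(self, dim=256, depth=8, num_queries=100, num_classes=80, patch=16, img=224):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True) for _ in range(depth)
        )
        self.queries = nn.Parameter(torch.randn(num_queries, dim))  # learnable query tokens
        self.class_head = nn.Linear(dim, num_classes + 1)           # +1 for "no object"
        self.mask_head = nn.Linear(dim, dim)
        self.grid = img // patch

    def forward(self, x):
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)     # B x N x D patch tokens
        for blk in self.blocks[: len(self.blocks) // 2]:            # early blocks: patches only
            tokens = blk(tokens)
        q = self.queries.unsqueeze(0).expand(x.shape[0], -1, -1)
        tokens = torch.cat([q, tokens], dim=1)                      # append queries as extra tokens
        for blk in self.blocks[len(self.blocks) // 2:]:             # later blocks: joint processing
            tokens = blk(tokens)
        q, patches = tokens[:, :q.shape[1]], tokens[:, q.shape[1]:]
        class_logits = self.class_head(q)                           # per-query class prediction
        mask_logits = torch.einsum("bqd,bnd->bqn", self.mask_head(q), patches)
        return class_logits, mask_logits.reshape(x.shape[0], -1, self.grid, self.grid)

class_logits, mask_logits = ToyEoMT()(torch.randn(1, 3, 224, 224))
print(class_logits.shape, mask_logits.shape)  # (1, 100, 81) and (1, 100, 14, 14)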
EoMT is also available on the main branch of Hugging Face Transformers. To install from source:
pip install git+https://github.com/huggingface/transformers
You can run segmentation with EoMT in just a few lines using the official 🤗 Transformers integration:
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import EomtForUniversalSegmentation, AutoImageProcessor
model_id = "tue-mps/coco_panoptic_eomt_large_640"
processor = AutoImageProcessor.from_pretrained(model_id)
model = EomtForUniversalSegmentation.from_pretrained(model_id)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=image, return_tensors="pt")
with torch.inference_mode():
    outputs = model(**inputs)
original_image_sizes = [(image.height, image.width)]
preds = processor.post_process_panoptic_segmentation(outputs, original_image_sizes)
plt.imshow(preds[0]["segmentation"])
plt.axis("off")
plt.title("Panoptic Segmentation")
plt.show()
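The returned prediction should also include a segments_info list alongside the segmentation map, following the standard Hugging Face panoptic post-processing format (treat the exact keys below as an assumption). For example, to print the detected classes:
# Assumes the standard HF panoptic output: a "segments_info" list with id/label_id/score keys.
for segment in preds[0]["segments_info"]:
    label = model.config.id2label[segment["label_id"]]
    print(f"segment {segment['id']}: {label} (score {segment['score']:.2f})")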
If you don't have Conda installed, install Miniconda and restart your shell:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
Then create the environment, activate it, and install the dependencies:
conda create -n eomt python==3.13.2
conda activate eomt
python3 -m pip install -r requirements.txt
Weights & Biases (wandb) is used for experiment logging and visualization. To enable wandb, log in to your account:
wandb login
Download the datasets below, depending on which ones you plan to use.
You do not need to unzip any of the downloaded files: simply place them in a directory of your choice and provide that path via the --data.path argument. The code reads the .zip files directly.
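For reference, reading images straight out of a .zip without extracting looks roughly like this. This is only an illustrative sketch, not the repository's actual data loader, and the archive path and file extension are placeholders:
import zipfile
from io import BytesIO
from PIL import Image

# Open one of the downloaded archives and decode the first image it contains.
with zipfile.ZipFile("/path/to/dataset/val2017.zip") as archive:
    member = next(name for name in archive.namelist() if name.endswith(".jpg"))
    image = Image.open(BytesIO(archive.read(member))).convert("RGB")
    print(member, image.size)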
COCO
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
wget http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip
ADE20K
wget http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
wget http://sceneparsing.csail.mit.edu/data/ChallengeData2017/annotations_instance.tar
tar -xf annotations_instance.tar
zip -r -0 annotations_instance.zip annotations_instance/
rm -rf annotations_instance.tar
rm -rf annotations_instance
Cityscapes
wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=<your_username>&password=<your_password>&submit=Login' https://www.cityscapes-dataset.com/login/
wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1
wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3
🔧 Replace <your_username> and <your_password> with your actual Cityscapes login credentials.
To train EoMT from scratch, run:
python3 main.py fit \
-c configs/coco/panoptic/eomt_large_640.yaml \
--trainer.devices 4 \
--data.batch_size 4 \
--data.path /path/to/dataset
This command trains the EoMT-L model with a 640×640 input size on COCO panoptic segmentation using 4 GPUs. Each GPU processes a batch of 4 images, for a total batch size of 16.
✅ Make sure the total batch size is devices × batch_size = 16
🔧 Replace /path/to/dataset with the directory containing the dataset zip files.
This configuration takes ~6 hours on 4×NVIDIA H100 GPUs, each using ~26GB VRAM.
To fine-tune a pre-trained EoMT model, add:
--model.ckpt_path /path/to/pytorch_model.bin \
--model.load_ckpt_class_head False
🔧 Replace /path/to/pytorch_model.bin with the path to the checkpoint to fine-tune.
--model.load_ckpt_class_head False skips loading the classification head when fine-tuning on a dataset with different classes.
To evaluate a pre-trained EoMT model, run:
python3 main.py validate \
-c configs/coco/panoptic/eomt_large_640.yaml \
--model.network.masked_attn_enabled False \
--trainer.devices 4 \
--data.batch_size 4 \
--data.path /path/to/dataset \
--model.ckpt_path /path/to/pytorch_model.bin
This command evaluates the same EoMT-L model using 4 GPUs with a batch size of 4 per GPU.
🔧 Replace /path/to/dataset with the directory containing the dataset zip files.
🔧 Replace /path/to/pytorch_model.bin with the path to the checkpoint to evaluate.
A notebook is available for quick inference and visualization with auto-downloaded pre-trained models.
FPS measured on NVIDIA H100, unless otherwise specified.
COCO panoptic segmentation, 640×640
Config | Input size | FPS | PQ | Download |
---|---|---|---|---|
EoMT-S2x | 640×640 | 330 | 46.7 | Model Weights |
EoMT-B2x | 640×640 | 261 | 51.6 | Model Weights |
EoMT-L | 640×640 | 128 | 56.0 | Model Weights |
EoMT-g | 640×640 | 55 | 57.0 | Model Weights |
EoMT-7B | 640×640 | 32* | 58.4 | Model Weights |
ViT-Adapter-7B + M2F | 640×640 | 17* | 58.4 | - |
2x Longer training schedule. * FPS measured on NVIDIA B200.
COCO panoptic segmentation, 1280×1280
Config | Input size | FPS | PQ | Download |
---|---|---|---|---|
EoMT-L | 1280×1280 | 30 | 58.3 | Model Weights |
EoMT-g | 1280×1280 | 12 | 59.2 | Model Weights |
ADE20K panoptic segmentation, 640×640
Config | Input size | FPS | PQ | Download |
---|---|---|---|---|
EoMT-L | 640×640 | 128 | 50.6C | Model Weights |
EoMT-g | 640×640 | 55 | 51.3C | Model Weights |
ADE20K panoptic segmentation, 1280×1280
Config | Input size | FPS | PQ | Download |
---|---|---|---|---|
EoMT-L | 1280×1280 | 30 | 51.7C | Model Weights |
EoMT-g | 1280×1280 | 12 | 52.8C | Model Weights |
C Models pre-trained on COCO panoptic segmentation. See above for how to load a checkpoint.
Cityscapes semantic segmentation, 1024×1024
Config | Input size | FPS | mIoU | Download |
---|---|---|---|---|
EoMT-L | 1024×1024 | 25 | 84.2 | Model Weights |
ADE20K semantic segmentation, 512×512
Config | Input size | FPS | mIoU | Download |
---|---|---|---|---|
EoMT-L | 512×512 | 92 | 58.4 | Model Weights |
COCO instance segmentation, 640×640
Config | Input size | FPS | mAP | Download |
---|---|---|---|---|
EoMT-L | 640×640 | 128 | 45.2* | Model Weights |
COCO instance segmentation, 1280×1280
Config | Input size | FPS | mAP | Download |
---|---|---|---|---|
EoMT-L | 1280×1280 | 30 | 48.8* | Model Weights |
* mAP reported using pycocotools; TorchMetrics (used by default) yields ~0.7 lower.
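For reference, a minimal sketch of recomputing mask mAP with pycocotools from a COCO-format results file (the file paths are placeholders, and this is not part of this repository's evaluation code):
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# Ground-truth instance annotations and a JSON file of predictions in COCO results format.
coco_gt = COCO("annotations/instances_val2017.json")
coco_dt = coco_gt.loadRes("predictions.json")
evaluator = COCOeval(coco_gt, coco_dt, iouType="segm")  # "segm" = mask mAP
evaluator.evaluate()
evaluator.accumulate()
evaluator.summarize()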
If you find this work useful in your research, please cite it using the BibTeX entry below:
@inproceedings{kerssies2025eomt,
author = {Kerssies, Tommie and Cavagnero, Niccol\`{o} and Hermans, Alexander and Norouzi, Narges and Averta, Giuseppe and Leibe, Bastian and Dubbelman, Gijs and de Geus, Daan},
title = {Your ViT is Secretly an Image Segmentation Model},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2025},
}
This project builds upon code from the following libraries and repositories:
- Hugging Face Transformers (Apache-2.0 License)
- PyTorch Image Models (timm) (Apache-2.0 License)
- PyTorch Lightning (Apache-2.0 License)
- TorchMetrics (Apache-2.0 License)
- Mask2Former (Apache-2.0 License)
- Detectron2 (Apache-2.0 License)