ViT Prisma Logo

Prisma: An Open Source Toolkit for Mechanistic Interpretability in Vision and Video

Prisma contains code for vision and video mechanistic interpretability, including activation caching and SAE training. We support a variety of vision and video models from Huggingface and OpenCLIP. This library was originally created by Sonia Joseph (see the full list of contributors here).

Mechanistic interpretability is broadly split into two parts: circuit analysis and sparse autoencoders (SAEs). Circuit analysis finds causal links between internal components of the model and relies primarily on activation caching. SAEs act as a finer-grained primitive: they decompose intermediate activations into sparse, more interpretable features. Prisma has the infrastructure for both.

We also include a suite of open source SAEs for all layers of CLIP and DINO, including transcoders for all layers of CLIP, that you can download from Huggingface.

For more details, check out our whitepaper, Prisma: An Open Source Toolkit for Mechanistic Interpretability in Vision and Video. Also, check out the original LessWrong post here.


Installation

We recommend installing from source:

git clone https://github.com/soniajoseph/ViT-Prisma
cd ViT-Prisma
pip install -e .

Models Supported

We support most vision/video transformers loaded from OpenCLIP and Huggingface, including ViTs, CLIP, DINO, and JEPA, with a few exceptions (e.g. if the architecture is substantially different).

For a list of model names, check out our model registry here.

To load a model:

from vit_prisma.models.model_loader import load_hooked_model

model_name = "open-clip:laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K
model = load_hooked_model(model_name)
model.to('cuda') # Move to cuda if available
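
Once a hooked model is loaded, you can cache its intermediate activations for circuit analysis. Below is a minimal sketch, assuming the hooked model follows a TransformerLens-style API with run_with_cache and hook names like "blocks.<i>.hook_resid_post" (check your model's hook points for the exact names); the random image tensor is a stand-in for a properly preprocessed batch:

import torch

# Stand-in for a batch of preprocessed images (use your model's preprocessing in practice)
images = torch.randn(1, 3, 224, 224).to('cuda')

with torch.no_grad():
    output, cache = model.run_with_cache(images)  # assumed TransformerLens-style call

resid = cache["blocks.10.hook_resid_post"]  # assumed hook name: residual stream after block 10
print(resid.shape)  # (batch, n_tokens, d_model)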

SAE Pretrained Weights, Training, and Evaluation Code

SAE Demo Notebooks

Notebooks are available to load, train, and evaluate SAEs on the supported models.

To load an SAE (see notebook for details):

from huggingface_hub import hf_hub_download, list_repo_files
from vit_prisma.sae import SparseAutoencoder

# Step 1: Download SAE weights and config from Huggingface
repo_id = "Prisma-Multimodal/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-1e-05" # Change this to your chosen SAE. See docs/sae_table.md for a full list of SAEs.
sae_path = hf_hub_download(repo_id, filename="weights.pt") # The file is usually named weights.pt but may vary slightly; use list_repo_files(repo_id) or check the HF repo for the exact name
hf_hub_download(repo_id, filename="config.json")

# Step 2: Load the pretrained SAE weights from the downloaded path
sae = SparseAutoencoder.load_from_pretrained(sae_path) # Automatically picks up the accompanying config.json and converts it into a VisionSAERunnerConfig object
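
After loading, you can run the SAE on activations cached at the hook point it was trained on. The following is a minimal sketch under a few assumptions: the SAE above was trained on layer 10 hook_mlp_out (as its repo_id suggests), the hooked model exposes that hook name, and the SparseAutoencoder exposes encode/decode methods on activations of shape (batch, tokens, d_model); check sae.cfg and the class itself for the exact interface.

import torch

# Stand-in for a batch of preprocessed images
images = torch.randn(1, 3, 224, 224).to('cuda')

with torch.no_grad():
    _, cache = model.run_with_cache(images)   # model from the loading example above
    acts = cache["blocks.10.hook_mlp_out"]    # assumed hook name matching the SAE's training hook point
    feature_acts = sae.encode(acts)           # assumed method: sparse feature activations
    recon = sae.decode(feature_acts)          # assumed method: reconstruction of the original activations

print((feature_acts > 0).float().mean())      # fraction of active features (rough sparsity check)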

Suite of Pretrained Vision SAE Weights

For a full list of SAEs for all layers, including CLIP top k, CLIP transcoders, and DINO SAEs, see here.

We recommend starting with the vanilla CLIP SAEs, which are the highest quality. If you are just getting started with steering CLIP's output, we recommend using the Layer 11 resid-post SAE.
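
As a rough illustration of what steering with that SAE can look like, the sketch below adds a scaled decoder direction for a single SAE feature to the layer 11 residual stream during a forward pass. The hook name, the run_with_hooks call, and the W_dec attribute are assumptions based on a TransformerLens-style API and standard SAE implementations; adjust them to the actual interface.

import torch

feature_idx = 0  # hypothetical feature index to amplify
alpha = 4.0      # steering strength (illustrative value; tune empirically)

# Stand-in for a batch of preprocessed images
images = torch.randn(1, 3, 224, 224).to('cuda')

def steer_hook(resid, hook):
    # Add the SAE feature's decoder direction to every token's residual stream
    direction = sae.W_dec[feature_idx]  # assumed attribute: (d_model,) decoder row for this feature
    return resid + alpha * direction

with torch.no_grad():
    steered_output = model.run_with_hooks(
        images,
        fwd_hooks=[("blocks.11.hook_resid_post", steer_hook)],  # assumed hook name for the layer 11 resid-post SAE
    )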

CLIP Vanilla SAEs (All Patches)

| Model | Layer | Sublayer | l1 coeff. | % Explained var. | Avg L0 | Avg CLS L0 | Cos sim | Recon cos sim | CE | Recon CE | Zero abl CE | % CE recovered | % Alive features |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| link | 0 | mlp_out | 1e-5 | 98.7 | 604.44 | 36.92 | 0.994 | 0.998 | 6.762 | 6.762 | 6.779 | 99.51 | 100 |
| link | 0 | resid_post | 1e-5 | 98.6 | 1110.9 | 40.46 | 0.993 | 0.988 | 6.762 | 6.763 | 6.908 | 99.23 | 100 |
| link | 1 | mlp_out | 1e-5 | 98.4 | 1476.8 | 97.82 | 0.992 | 0.994 | 6.762 | 6.762 | 6.889 | 99.40 | 100 |
| link | 1 | resid_post | 1e-5 | 98.3 | 1508.4 | 27.39 | 0.991 | 0.989 | 6.762 | 6.763 | 6.908 | 99.02 | 100 |
| link | 2 | mlp_out | 1e-5 | 98.0 | 1799.7 | 376.0 | 0.992 | 0.998 | 6.762 | 6.762 | 6.803 | 99.44 | 100 |
| link | 2 | resid_post | 5e-5 | 90.6 | 717.84 | 10.11 | 0.944 | 0.960 | 6.762 | 6.767 | 6.908 | 96.34 | 100 |
| link | 3 | mlp_out | 1e-5 | 98.1 | 1893.4 | 648.2 | 0.992 | 0.999 | 6.762 | 6.762 | 6.784 | 99.54 | 100 |
| link | 3 | resid_post | 1e-5 | 98.1 | 2053.9 | 77.90 | 0.989 | 0.996 | 6.762 | 6.762 | 6.908 | 99.79 | 100 |
| link | 4 | mlp_out | 1e-5 | 98.1 | 1901.2 | 1115.0 | 0.993 | 0.999 | 6.762 | 6.762 | 6.786 | 99.55 | 100 |
| link | 4 | resid_post | 1e-5 | 98.0 | 2068.3 | 156.7 | 0.989 | 0.997 | 6.762 | 6.762 | 6.908 | 99.74 | 100 |
| link | 5 | mlp_out | 1e-5 | 98.2 | 1761.5 | 1259.0 | 0.993 | 0.999 | 6.762 | 6.762 | 6.797 | 99.76 | 100 |
| link | 5 | resid_post | 1e-5 | 98.1 | 1953.8 | 228.5 | 0.990 | 0.997 | 6.762 | 6.762 | 6.908 | 99.80 | 100 |
| link | 6 | mlp_out | 1e-5 | 98.3 | 1598.0 | 1337.0 | 0.993 | 0.999 | 6.762 | 6.762 | 6.789 | 99.83 | 100 |
| link | 6 | resid_post | 1e-5 | 98.2 | 1717.5 | 321.3 | 0.991 | 0.996 | 6.762 | 6.762 | 6.908 | 99.93 | 100 |
| link | 7 | mlp_out | 1e-5 | 98.2 | 1535.3 | 1300.0 | 0.992 | 0.999 | 6.762 | 6.762 | 6.796 | 100.17 | 100 |
| link | 7 | resid_post | 1e-5 | 98.2 | 1688.4 | 494.3 | 0.991 | 0.995 | 6.762 | 6.761 | 6.908 | 100.24 | 100 |
| link | 8 | mlp_out | 1e-5 | 97.8 | 1074.5 | 1167.0 | 0.990 | 0.998 | 6.762 | 6.761 | 6.793 | 100.57 | 100 |
| link | 8 | resid_post | 1e-5 | 98.2 | 1570.8 | 791.3 | 0.991 | 0.992 | 6.762 | 6.761 | 6.908 | 100.41 | 100 |
| link | 9 | mlp_out | 1e-5 | 97.6 | 856.68 | 1076.0 | 0.989 | 0.998 | 6.762 | 6.762 | 6.792 | 100.28 | 100 |
| link | 9 | resid_post | 1e-5 | 98.2 | 1533.5 | 1053.0 | 0.991 | 0.989 | 6.762 | 6.761 | 6.908 | 100.32 | 100 |
| link | 10 | mlp_out | 1e-5 | 98.1 | 788.49 | 965.5 | 0.991 | 0.998 | 6.762 | 6.762 | 6.772 | 101.50 | 99.80 |
| link | 10 | resid_post | 1e-5 | 98.4 | 1292.6 | 1010.0 | 0.992 | 0.987 | 6.762 | 6.760 | 6.908 | 100.83 | 99.99 |
| link | 11 | mlp_out | 5e-5 | 89.7 | 748.14 | 745.5 | 0.972 | 0.993 | 6.762 | 6.759 | 6.768 | 135.77 | 100 |
| link | 11 | resid_post | 1e-5 | 98.4 | 1405.0 | 1189.0 | 0.993 | 0.987 | 6.762 | 6.765 | 6.908 | 98.03 | 99.99 |

DINO (Vanilla, all patches)

| Model | Layer | Sublayer | Avg L0 | % Explained var. | Avg CLS L0 | Cos sim | CE | Recon CE | Zero abl CE | % CE recovered |
|---|---|---|---|---|---|---|---|---|---|---|
| link | 0 | resid_post | 507 | 98 | 347 | 0.95009 | 1.885033 | 1.936518 | 7.2714 | 99.04 |
| link | 1 | resid_post | 549 | 95 | 959 | 0.93071 | 1.885100 | 1.998274 | 7.2154 | 97.88 |
| link | 2 | resid_post | 812 | 95 | 696 | 0.95600 | 1.885134 | 2.006115 | 7.2015 | 97.72 |
| link | 3 | resid_post | 989 | 95 | 616 | 0.96315 | 1.885131 | 1.961913 | 7.2068 | 98.56 |
| link | 4 | resid_post | 876 | 99 | 845 | 0.99856 | 1.885224 | 1.883169 | 7.1636 | 100.04 |
| link | 5 | resid_post | 1001 | 98 | 889 | 0.99129 | 1.885353 | 1.875520 | 7.1412 | 100.19 |
| link | 6 | resid_post | 962 | 99 | 950 | 0.99945 | 1.885239 | 1.872594 | 7.1480 | 100.24 |
| link | 7 | resid_post | 1086 | 98 | 1041 | 0.99341 | 1.885371 | 1.869443 | 7.1694 | 100.30 |
| link | 8 | resid_post | 530 | 90 | 529 | 0.94750 | 1.885511 | 1.978638 | 7.1315 | 98.22 |
| link | 9 | resid_post | 1105 | 99 | 1090 | 0.99541 | 1.885341 | 1.894026 | 7.0781 | 99.83 |
| link | 10 | resid_post | 835 | 99 | 839 | 0.99987 | 1.885371 | 1.884487 | 7.3606 | 100.02 |
| link | 11 | resid_post | 1085 | 99 | 1084 | 0.99673 | 1.885370 | 1.911608 | 6.9078 | 99.48 |

CLIP Transcoders

CLIP Top-K transcoder performance metrics for all patches.

| Model | Layer | Block | % Explained var. | k | Avg CLS L0 | Cos sim | CE | Recon CE | Zero abl CE | % CE recovered |
|---|---|---|---|---|---|---|---|---|---|---|
| link | 0 | MLP | 96 | 768 | 767 | 0.9655 | 6.7621 | 6.7684 | 6.8804 | 94.68 |
| link | 1 | MLP | 94 | 256 | 255 | 0.9406 | 6.7621 | 6.7767 | 6.8816 | 87.78 |
| link | 2 | MLP | 93 | 1024 | 475 | 0.9758 | 6.7621 | 6.7681 | 6.7993 | 83.92 |
| link | 3 | MLP | 90 | 1024 | 825 | 0.9805 | 6.7621 | 6.7642 | 6.7999 | 94.42 |
| link | 4 | MLP | 76 | 512 | 29 | 0.9830 | 6.7621 | 6.7636 | 6.8080 | 96.76 |
| link | 5 | MLP | 91 | 1024 | 1017 | 0.9784 | 6.7621 | 6.7643 | 6.8296 | 96.82 |
| link | 6 | MLP | 94 | 1024 | 924 | 0.9756 | 6.7621 | 6.7630 | 6.8201 | 98.40 |
| link | 7 | MLP | 97 | 1024 | 1010 | 0.9629 | 6.7621 | 6.7631 | 6.8056 | 97.68 |
| link | 8 | MLP | 98 | 1024 | 1023 | 0.9460 | 6.7621 | 6.7630 | 6.8017 | 97.70 |
| link | 9 | MLP | 98 | 1024 | 1023 | 0.9221 | 6.7621 | 6.7630 | 6.7875 | 96.50 |
| link | 10 | MLP | 97 | 1024 | 1019 | 0.9334 | 6.7621 | 6.7636 | 6.7860 | 93.95 |

More details are in our whitepaper here. For more SAEs, including CLS-only and spatial patch-only variants, see the SAE table. We've also visualized some Prisma SAEs here.

Basic Mechanistic Interpretability

An earlier version of Prisma included features for basic mechanistic interpretability, including the logit lens and attention head visualizations. In addition to the tutorial notebooks below, you can also check out the corresponding talk covering some of these techniques.

  1. Main ViT Demo - Overview of the main mechanistic interpretability techniques on a ViT, including direct logit attribution, attention head visualization, and activation patching. The activation patching example switches the network's prediction from tabby cat to Border Collie with a minimal, targeted intervention (a rough code sketch of this workflow follows this list).
  2. Emoji Logit Lens - Deeper dive into layer- and patch-level predictions with interactive plots.
  3. Interactive Attention Head Tour - Deeper dive into the various types of attention heads a ViT contains with interactive JavaScript.
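
To make the patching workflow above concrete, here is a minimal sketch under the assumption that the hooked model follows a TransformerLens-style API (run_with_cache, run_with_hooks, and hook names like "blocks.<i>.hook_resid_post" are assumptions; the image tensors are random stand-ins for real preprocessed batches):

import torch

# Stand-ins for preprocessed "clean" and "corrupted" image batches (e.g., original vs. perturbed image)
clean_images = torch.randn(1, 3, 224, 224).to('cuda')
corrupted_images = torch.randn(1, 3, 224, 224).to('cuda')

# Cache activations from the clean run, then patch one layer's activation into the corrupted run
with torch.no_grad():
    _, clean_cache = model.run_with_cache(clean_images)  # assumed TransformerLens-style call

def patch_resid(resid, hook):
    # Overwrite this hook's activation with the clean run's activation (assumes the cache is indexable by hook name)
    return clean_cache[hook.name]

with torch.no_grad():
    patched_logits = model.run_with_hooks(
        corrupted_images,
        fwd_hooks=[("blocks.6.hook_resid_post", patch_resid)],  # assumed hook name; choose the layer you want to test
    )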

Features

For a demo of Prisma's mech interp features, including interactive versions of the visualizations below, check out the demo notebooks above.

Attention head visualization


Activation patching

Direct logit attribution

Emoji logit lens

Custom Models & Checkpoints

ImageNet-1k classification checkpoints (patch size 32)

All models include training checkpoints, in case you want to analyze training dynamics.

The larger patch size produces fewer tokens, so the attention heads are small enough to inspect interactively; the patch size 16 attention patterns are too large to render easily in JavaScript.

| Size | NumLayers | Attention+MLP | AttentionOnly | Model Link |
|---|---|---|---|---|
| tiny | 3 | 0.22 / 0.42 | N/A | Attention+MLP |

ImageNet-1k classification checkpoints (patch size 16)

The detailed training logs and metrics can be found here. These models were trained by Yash Vadi.

Table of Results

Accuracy is reported as <Acc> / <Top5 Acc>.

| Size | NumLayers | Attention+MLP | AttentionOnly | Model Link |
|---|---|---|---|---|
| tiny | 1 | 0.16 / 0.33 | 0.11 / 0.25 | AttentionOnly, Attention+MLP |
| base | 2 | 0.23 / 0.44 | 0.16 / 0.34 | AttentionOnly, Attention+MLP |
| small | 3 | 0.28 / 0.51 | 0.17 / 0.35 | AttentionOnly, Attention+MLP |
| medium | 4 | 0.33 / 0.56 | 0.17 / 0.36 | AttentionOnly, Attention+MLP |

Contributors

This library was originally created by Sonia Joseph, alongside fantastic contributors: Praneet Suresh, Yash Vadi, Rob Graham, Lorenz Hufe, Edward Stevinson, and Ethan Goldfarb, with more coming soon. You can learn more about our contributors on our Contributors page (coming soon). Also, check out our whitepaper.

Thank you to Leo Gao, Joseph Bloom, Lee Sharkey, Neel Nanda, and Yossi Gandelsman for the feedback and discussions at the beginning of this repo's development.

We welcome new contributors. Check out our contributing guidelines here and our open Issues.

Citation

Please cite this repository if you use it in papers or research projects. Thank you for supporting the community!

@misc{joseph2025prismaopensourcetoolkit,
      title={Prisma: An Open Source Toolkit for Mechanistic Interpretability in Vision and Video}, 
      author={Sonia Joseph and Praneet Suresh and Lorenz Hufe and Edward Stevinson and Robert Graham and Yash Vadi and Danilo Bzdok and Sebastian Lapuschkin and Lee Sharkey and Blake Aaron Richards},
      year={2025},
      eprint={2504.19475},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2504.19475}, 
}
@misc{joseph2023vit,
  author = {Sonia Joseph},
  title = {ViT Prisma: A Mechanistic Interpretability Library for Vision Transformers},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/soniajoseph/vit-prisma}}
}

License

This project is released under the MIT License (see here).
