This repository provides tools and scripts for mechanistic interpretability of medical imaging models using Sparse Autoencoders (SAEs). It supports training, evaluation, and analysis of SAEs on neural network activations, with a focus on models such as CLIP and MedCLIP.
- 📂 Project Structure
- ⚙️ Installation
- 📊 Datasets
- 🧠 Models
- 💾 Computing Activations & SAE Features
- 🏋️ SAE Training
- 🗂️ Data Naming Convention
- 🚀 Quick Start: Training an SAE, Assigning Concepts, and Evaluating Interpretability
- 🔧 Useful Commands
- 📚 References
```
.
├── data/                                        # Data produced by the different scripts and pipeline steps
│   ├── activations/                             # Tensors of activations (advised naming: "model_dataset_split_additionalprocessing.pt")
│   ├── assign_concepts/                         # CSV tables of concept/neuron matchings
│   ├── probe_checkpoints/                       # Tensors of probe weights trained with "train_probe.py"
│   ├── sae_checkpoints/                         # Tensors of SAE weights (advised naming: "model_dataset_sae-architecture_remarks.pt")
│   ├── sae_features/                            # Tensors of SAE features (or activations) (advised naming: "model_dataset_sae-architecture_remarks_sae-dataset_sae-split.pt")
│   └── vocabs/                                  # Txt vocabulary database
├── datasets/                                    # Datasets for training/evaluation (cc3m, chexpert, etc.)
├── dictionary_learning/                         # Core dictionary learning package (SAE training, buffer, trainers), cloned from https://github.com/saprmarks/dictionary_learning/ and adapted for CLIP and MedCLIP models (see the section on dictionary_learning adaptation)
├── logs/                                        # Training and evaluation logs
├── results/                                     # Output results (plots, metrics, etc.)
├── src/                                         # Main source code for analysis, training, and plotting
│   ├── 1d_probe_training.py                     # Train probes on 1D features for interpretability
│   ├── assign_names.py                          # Assign human-readable names from a word vocabulary to SAE features
│   ├── assign_names_VLM.py                      # Assign names using the vision-language model MedGemma
│   ├── cc3m_training.py                         # Train SAEs on the CC3M dataset
│   ├── cc3m_training_BatchTopK.py               # Train SAEs on CC3M with the BatchTopK architecture
│   ├── chexpert_score.py                        # Evaluate the CheXpert score (essentially neuron-class correlation scores) for models/SAEs
│   ├── chexpert_training.py                     # Train SAEs on the CheXpert dataset
│   ├── compute_sweep_activations.py             # Compute activations for sweep experiments
│   ├── compute_sweep_deadlatents.py             # Analyze dead latents in sweep experiments
│   ├── config.py                                # Configuration for models, datasets, and paths
│   ├── data_preprocess.py                       # Data preprocessing utilities
│   ├── generate_medical_vocabulary.py           # Generate medical vocabulary files from a given text corpus
│   ├── linearity_sanity_check.py                # Sanity check for linearity in features by analyzing combined images (see Report)
│   ├── linearity_sanity_check_pairs.py          # Linearity check for combined pairs
│   ├── linearity_sanity_check_negative_allowed.py  # Linearity check allowing negative least-squares solutions
│   ├── monosemanticity_score.py                 # Compute monosemanticity score for SAE/model activations (see https://arxiv.org/pdf/2504.02821v1)
│   ├── monosemanticity_score_actversion.py      # Alternative monosemanticity score computation (adaptation of https://arxiv.org/abs/2501.18052)
│   ├── naming_eval.py                           # Evaluate naming quality for SAE features with MedGemma (see Report)
│   ├── neuron_class_correlation_comparaison.py  # Compare neuron-class correlations
│   ├── plot_feature_sparsity_distribution.py    # Plot feature sparsity distributions
│   ├── plot_naming_eval.py                      # Plot results of the naming evaluation (score distribution)
│   ├── plot_neuron_class_correlation.py         # Plot neuron-class correlation results (correlation matrix)
│   ├── plot_top_activating_images.py            # Visualize top-activating images for features
│   ├── save_activations_or_features.py          # Save activations or SAE features to disk
│   ├── sweep.py                                 # Run parameter sweeps for training/evaluation
│   ├── sweep_config.py                          # Configuration for sweep experiments
│   ├── train_probe.py                           # Train probes on SAE or model features
│   └── utils.py                                 # General utility functions (data processing, model import, etc.)
├── vocab/                                       # Vocabulary files
├── build.sh                                     # Build script
├── Dockerfile*                                  # Docker configurations
├── requirements.txt                             # Python dependencies
├── README.md                                    # (This file)
└── LICENSE
```
Install dependencies:

```shell
pip install -r requirements.txt
```

Install dictionary_learning:

```shell
pip install -e dictionary_learning/
```

Install transformers version 4.24.0:

```shell
pip install transformers==4.24.0
```

Fix the current MedCLIP library issue: add `strict=False` in this file: https://github.com/RyanWangZf/MedCLIP/blob/main/medclip/modeling_medclip.py#L185 (see RyanWangZf/MedCLIP#37).
Supported datasets:
["chexpert", "cc3m", "mimic-cxr", "cifar100", "cifar10", "places365"]
- Torch datasets (`cifar100`, `cifar10`, `places365`) → downloaded automatically.
- External datasets require manual setup:
  - CC3M: `datasets/cc3m/` with `cc3m-train-{0000..0575}.tar` and `cc3m-validation-{0000..0015}.tar`
  - CheXpert: `datasets/chexpert/` with CSVs (`chexpert_train.csv`, `chexpert_valid.csv`; `chexpert_train_shuffled.csv` optional). Must include a `Path` column (image paths) plus one column per label (0/1).
  - MIMIC-CXR: downloaded via Hugging Face (authentication may be required).
Usage: select the dataset via the `--dataset` flag in scripts. Some scripts are dataset-specific (e.g., `chexpert_score.py`).
The list of supported datasets can easily be extended by modifying `utils.py` and `data_preprocess.py`.
Supported models:
["MedCLIP-RN50", "CLIP-RN50", "MedCLIP-ViT"]
- CLIP models: from OpenAI CLIP
- MedCLIP models: from MedCLIP repo
Extend by editing src/utils.py + src/config.py, and adding buffers in dictionary_learning/buffer.py.
Activations and SAE features can be computed either:

- independently, using `src/save_activations_or_features.py`
- on the fly, within other scripts that require them
Because these computations are time-consuming, the workflow is designed so that they are performed once, saved in the data/ directory (under data/activations/ or data/sae_features/), and then reused automatically.
Each script checks whether the requested activations or features already exist (see Data Naming Convention):
- ✅ If available → they are loaded from disk
- ❌ If not → they are computed and saved for future use
You can also provide custom precomputed data with:
- `--load_activations_path` → specify a path to activations
- `--load_features_path` → specify a path to SAE features
This enables you to preprocess data manually and still use it seamlessly in the scripts.
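The check-load-or-compute pattern described above can be sketched generically. `load_or_compute` and its callbacks are illustrative names, not functions from this repo; the actual scripts use `torch.save`/`torch.load` on `.pt` tensors under `data/activations/` and `data/sae_features/`:

```python
from pathlib import Path

def load_or_compute(cache_path, compute_fn, load_fn, save_fn):
    """Generic sketch of the repo's caching workflow: reuse a saved file if
    it exists, otherwise compute once and save it for future runs."""
    path = Path(cache_path)
    if path.exists():
        return load_fn(path)        # ✅ available → load from disk
    data = compute_fn()             # ❌ missing → compute once ...
    path.parent.mkdir(parents=True, exist_ok=True)
    save_fn(data, path)             # ... and save for reuse
    return data
```

On a second run with the same path, `compute_fn` is never called, which is exactly why expensive activations only need to be computed once.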
The repository implements buffered loading via FeatureBuffer and CLIPActivationBuffer (customized from the dictionary_learning library).
- This avoids loading entire datasets into memory, reducing RAM usage.
- Some scripts, however, bypass this and load all data directly into memory by iterating over the buffer.
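To illustrate the buffered-loading idea, here is a minimal, hypothetical stand-in for those classes (the real `FeatureBuffer`/`CLIPActivationBuffer` operate on torch activation tensors, not Python lists):

```python
class SimpleFeatureBuffer:
    """Minimal sketch of buffered loading: yield fixed-size chunks of the
    stored data instead of materializing the whole dataset in memory."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        # Serve one chunk at a time, keeping peak memory at ~batch_size items.
        for start in range(0, len(self.data), self.batch_size):
            yield self.data[start:start + self.batch_size]
```

"Bypassing" the buffer, as some scripts do, amounts to concatenating every chunk it yields, e.g. `all_items = [x for batch in buf for x in batch]`, which trades RAM for convenience.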
SAEs are trained with our customized version of dictionary_learning/training.py.
Example scripts:
- `src/cc3m_training.py`
- `src/chexpert_training.py`
- `src/sweep.py`, `src/sweep_config.py`
Available SAE architectures:
["Standard", "BatchTopKSAE", "StandardAprilUpdate", "MatryoshkaBatchTopKSAE"]
Details: dictionary_learning repo.
- SAE features: `model_dataset_sae-arch_remarks_sae-dataset_sae-split.pt`
- SAE checkpoints: `model_dataset_sae-arch_remarks.pt`
- Activations: `model_dataset_split_additionalprocessing.pt`
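A filename following this convention can be split back into its components. `parse_sae_feature_name` is an illustrative helper (not part of the repo's code), which assumes the optional `remarks` field is simply absent when only five parts are present:

```python
def parse_sae_feature_name(filename):
    """Parse an SAE-features filename of the advised form
    model_dataset_sae-arch[_remarks]_sae-dataset_sae-split.pt."""
    parts = filename.removesuffix(".pt").split("_")
    if len(parts) == 5:  # remarks omitted
        model, dataset, arch, sae_dataset, sae_split = parts
        remarks = None
    else:
        model, dataset, arch, remarks, sae_dataset, sae_split = parts
    return {"model": model, "dataset": dataset, "sae_arch": arch,
            "remarks": remarks, "sae_dataset": sae_dataset,
            "sae_split": sae_split}
```

Note this only works if the individual fields themselves contain no underscores, which the advised naming implicitly assumes.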
Before starting, make sure you have completed the installation steps.
This guide walks through the process step by step, using intermediate outputs so that errors can be isolated easily.
Example: `MedCLIP-RN50`

- If using a Torch dataset (e.g., `cifar100`, `cifar10`, `places365`), it will download automatically.
- If using CheXpert, download the dataset (e.g., Kaggle CheXpert) and check that the paths in the CSV files point to the correct image locations.
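Checking that the CSV paths resolve correctly can be automated with a short script. `check_chexpert_paths` and its default `root` are illustrative, not part of the repo; only the `Path` column name comes from the dataset description above:

```python
import csv
from pathlib import Path

def check_chexpert_paths(csv_path, root="datasets/chexpert"):
    """Return the image paths listed in a CheXpert CSV that are missing
    on disk, resolving each 'Path' entry relative to `root`."""
    missing = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            img = Path(root) / row["Path"]
            if not img.exists():
                missing.append(str(img))
    return missing
```

An empty return value means every image referenced by the CSV was found.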
From the project root, compute train split activations:
```shell
python src/save_activations_or_features.py --model_name MedCLIP-RN50 --split train --dataset chexpert
```

This saves activations to:

`data/activations/MedCLIP-RN50_chexpert_train.pt`
For the validation split:
```shell
python src/save_activations_or_features.py --model_name MedCLIP-RN50 --split val --dataset chexpert
```

Use the training scripts provided (e.g., `src/chexpert_training.py` or `src/cc3m_training.py`).
- By default, `chexpert_training.py` uses CheXpert + MedCLIP-RN50 with the Standard SAE architecture.
- Modify training parameters (e.g., `batch_size`, `learning_rate`, SAE architecture) directly in the script.
- To change the SAE type, update the trainer imports:

```python
from dictionary_learning.trainers.standard import StandardTrainer
from dictionary_learning.trainer_config import StandardTrainerConfig
from dictionary_learning.dictionary import AutoEncoder
```

- Refer to the `dictionary_learning` library and the sweep scripts for customization examples.
Run training:
```shell
python src/chexpert_training.py
```

Checkpoints will be saved in:

`data/sae_checkpoints/`
Once training is done, compute SAE features with:
```shell
python src/save_activations_or_features.py --model_name MedCLIP-RN50 --split val --dataset chexpert --sae_path "MedCLIP-RN50_chexpert_standard.pt" --sae_type_name "Standard"
```

Output is saved in:

`data/sae_features/MedCLIP-RN50_chexpert_standard_chexpert_val.pt`
👉 You can also run this for other datasets/splits. If the corresponding activations are missing, both activations and SAE features will be computed automatically.
Two methods are supported:
- Using the MedGemma VLM → analyzes top/low activating images (`src/assign_names_VLM.py`)
- Using a vocabulary → matches words/sentences via cosine similarity in CLIP space (`src/assign_names.py`)
Example with MedGemma:
```shell
python src/assign_names_VLM.py --sae_path MedCLIP-RN50_chexpert_standard.pt --sae_type_name Standard --model_name MedCLIP-RN50 --split val --dataset chexpert
```

This saves a CSV in:

`data/assign_concepts/MedCLIP-RN50_chexpert_standard_reportsFalse.csv`
The file includes:
- `neuron` → feature index
- `concept` → assigned name
- `description` → explanation
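The resulting table can be loaded for downstream use, for example with the standard `csv` module. This is a sketch using only the three column names listed above; the repo's scripts may parse these files differently:

```python
import csv
import io

def load_concept_table(csv_text):
    """Read an assign-concepts CSV (columns: neuron, concept, description)
    into a {neuron_index: (concept, description)} dict."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return {int(row["neuron"]): (row["concept"], row["description"])
            for row in reader}
```

Indexing the dict by feature index then gives the assigned name and its explanation for any neuron of interest.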
MedGemma model page
Evaluation uses MedGemma as a classifier:
- It predicts whether an image activates a neuron based on the concept description.
- Results are compared against ground truth activations.
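Conceptually, this comparison reduces to an agreement score between the VLM's per-image predictions and the ground-truth activations binarized at some threshold. `naming_accuracy` below is a simplified sketch of that idea, not the exact metric implemented in `naming_eval.py`:

```python
def naming_accuracy(vlm_predictions, activations, threshold=0.0):
    """Fraction of images where the VLM's activate/not-activate prediction
    matches whether the neuron's true activation exceeds `threshold`."""
    truth = [a > threshold for a in activations]
    correct = sum(p == t for p, t in zip(vlm_predictions, truth))
    return correct / len(truth)
```

A score near 1.0 suggests the assigned concept genuinely describes what the feature responds to; a score near 0.5 means the name carries little predictive information.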
Run:
```shell
python src/assign_names_VLM.py --sae_path MedCLIP-RN50_chexpert_standard.pt --sae_type_name Standard --model_name MedCLIP-RN50 --split val --dataset chexpert --assigned_concepts MedCLIP-RN50_chexpert_standard_reportsFalse.csv
```

- Plot feature sparsity distribution:

```shell
python src/plot_feature_sparsity_distribution.py --sae_path clip-rn50_cc3m_repo.pt --sae_type_name Standard --model_name CLIP-RN50 --dataset cc3m --split val
```

- Visualize top activating images:

```shell
python src/plot_top_activating_images.py --sae_path clip-rn50_cc3m_repo.pt --sae_type_name Standard --model_name CLIP-RN50 --dataset places365 --split val
```

- Compute CheXpert score:

```shell
python src/chexpert_score.py --model_name MedCLIP-RN50 --split trainval --dataset chexpert --sae_path medclip-rn50_chexpert_standard.pt --sae_type_name Standard
```