Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions overcomplete/sae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .factory import EncoderFactory
from .jump_sae import JumpSAE, jump_relu, heaviside
from .topk_sae import TopKSAE
from .rasae import RATopKSAE, RAJumpSAE
from .qsae import QSAE
from .batchtopk_sae import BatchTopKSAE
from .mp_sae import MpSAE
Expand Down
54 changes: 27 additions & 27 deletions overcomplete/sae/jump_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class JumpReLU(torch.autograd.Function):
JumpReLU activation function with pseudo-gradient for threshold.
"""
@staticmethod
def forward(ctx, x, threshold, kernel_fn, bandwith):
def forward(ctx, x, threshold, kernel_fn, bandwidth):
"""
Forward pass of the JumpReLU activation function.
Save the necessary variables for the backward pass.
Expand All @@ -28,11 +28,11 @@ def forward(ctx, x, threshold, kernel_fn, bandwith):
Threshold tensor, learnable parameter.
kernel_fn : callable
Kernel function.
bandwith : float
Bandwith of the kernel.
bandwidth : float
Bandwidth of the kernel.
"""
ctx.save_for_backward(x, threshold)
ctx.bandwith = bandwith
ctx.bandwidth = bandwidth
ctx.kernel_fn = kernel_fn

output = x.clone()
Expand All @@ -52,7 +52,7 @@ def backward(ctx, grad_output):
Gradient of the loss w.r.t. the output.
"""
x, threshold = ctx.saved_tensors
bandwith = ctx.bandwith
bandwidth = ctx.bandwidth
kernel_fn = ctx.kernel_fn

# gradient w.r.t. input (normal gradient)
Expand All @@ -61,11 +61,11 @@ def backward(ctx, grad_output):

# pseudo-gradient w.r.t. threshold parameters
delta = x - threshold
kernel_values = kernel_fn(delta, bandwith)
kernel_values = kernel_fn(delta, bandwidth)

# @tfel: we have a singularity at threshold=0, thus the
# re-parametrization trick in JumpSAE class
grad_threshold = - (threshold / bandwith) * kernel_values * grad_output
grad_threshold = - (threshold / bandwidth) * kernel_values * grad_output
grad_threshold = grad_threshold.sum(0)

return grad_input, grad_threshold, None, None
Expand All @@ -81,7 +81,7 @@ class HeavisidePseudoGradient(torch.autograd.Function):
The pseudo-gradient is used to approximate the gradient at the threshold.
"""
@staticmethod
def forward(ctx, x, threshold, kernel_fn, bandwith):
def forward(ctx, x, threshold, kernel_fn, bandwidth):
"""
Forward pass of the Heaviside step function.
Save the necessary variables for the backward pass.
Expand All @@ -94,11 +94,11 @@ def forward(ctx, x, threshold, kernel_fn, bandwith):
Threshold tensor, learnable parameter.
kernel_fn : callable
Kernel function.
bandwith : float
Bandwith of the kernel.
bandwidth : float
Bandwidth of the kernel.
"""
ctx.save_for_backward(x, threshold)
ctx.bandwith = bandwith
ctx.bandwidth = bandwidth
ctx.kernel_fn = kernel_fn

output = (x > threshold).float()
Expand All @@ -117,22 +117,22 @@ def backward(ctx, grad_output):
Gradient of the loss w.r.t. the output.
"""
x, threshold = ctx.saved_tensors
bandwith = ctx.bandwith
bandwidth = ctx.bandwidth
kernel_fn = ctx.kernel_fn

delta = x - threshold
kernel_values = kernel_fn(delta, bandwith)
kernel_values = kernel_fn(delta, bandwidth)

# see the paper for the formula
grad_threshold = - (1 / bandwith) * kernel_values * grad_output
grad_threshold = - (1 / bandwidth) * kernel_values * grad_output
grad_threshold = grad_threshold.sum(0)

grad_input = torch.zeros_like(x)

return grad_input, grad_threshold, None, None


def jump_relu(x, threshold, kernel_fn, bandwith):
def jump_relu(x, threshold, kernel_fn, bandwidth):
"""
Apply the JumpReLU activation function to the input tensor.

Expand All @@ -144,18 +144,18 @@ def jump_relu(x, threshold, kernel_fn, bandwith):
Threshold tensor, learnable parameter.
kernel_fn : callable
Kernel function.
bandwith : float
Bandwith of the kernel.
bandwidth : float
Bandwidth of the kernel.

Returns
-------
torch.Tensor
Output tensor.
"""
return JumpReLU.apply(x, threshold, kernel_fn, bandwith)
return JumpReLU.apply(x, threshold, kernel_fn, bandwidth)


def heaviside(x, threshold, kernel_fn, bandwith):
def heaviside(x, threshold, kernel_fn, bandwidth):
"""
Apply the Heaviside step function to the input tensor.

Expand All @@ -167,15 +167,15 @@ def heaviside(x, threshold, kernel_fn, bandwith):
Threshold tensor, learnable parameter.
kernel_fn : callable
Kernel function.
bandwith : float
Bandwith of the kernel.
bandwidth : float
Bandwidth of the kernel.

Returns
-------
torch.Tensor
Output tensor.
"""
return HeavisidePseudoGradient.apply(x, threshold, kernel_fn, bandwith)
return HeavisidePseudoGradient.apply(x, threshold, kernel_fn, bandwidth)


class JumpSAE(SAE):
Expand Down Expand Up @@ -222,8 +222,8 @@ class JumpSAE(SAE):
- 'quartic'
- 'silverman'
- 'cauchy'.
bandwith : float, optional
Bandwith of the kernel, by default 1e-3.
bandwidth : float, optional
Bandwidth of the kernel, by default 1e-3.
encoder_module : nn.Module or string, optional
Custom encoder module, by default None.
If None, a simple Linear + BatchNorm default encoder is used.
Expand Down Expand Up @@ -257,7 +257,7 @@ class JumpSAE(SAE):
'cauchy': cauchy_kernel
}

def __init__(self, input_shape, nb_concepts, kernel='silverman', bandwith=1e-3,
def __init__(self, input_shape, nb_concepts, kernel='silverman', bandwidth=1e-3,
encoder_module=None, dictionary_params=None, device='cpu'):
assert isinstance(encoder_module, (str, nn.Module, type(None)))
assert isinstance(input_shape, (int, tuple, list))
Expand All @@ -267,7 +267,7 @@ def __init__(self, input_shape, nb_concepts, kernel='silverman', bandwith=1e-3,
dictionary_params, device)

self.kernel_fn = self._KERNELS[kernel]
self.bandwith = torch.tensor(bandwith, device=device)
self.bandwidth = torch.tensor(bandwidth, device=device)

# exp(-3) make the thresholds start around 0.05
self.thresholds = nn.Parameter(torch.ones(nb_concepts, device=device)*(-3.0), requires_grad=True)
Expand Down Expand Up @@ -297,7 +297,7 @@ def encode(self, x):
# see paper, appendix J
codes = torch.relu(pre_codes)

codes = jump_relu(codes, exp_thresholds, bandwith=self.bandwith,
codes = jump_relu(codes, exp_thresholds, bandwidth=self.bandwidth,
kernel_fn=self.kernel_fn)

return pre_codes, codes
110 changes: 110 additions & 0 deletions overcomplete/sae/rasae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
Module for Relaxed Archetypal SAE implementations.
For the implementation of the Relaxed Archetypal Dictionary, see archetypal_dictionary.py.
"""

import torch
from torch import nn

from .topk_sae import TopKSAE
from .jump_sae import JumpSAE
from .archetypal_dictionary import RelaxedArchetypalDictionary


class RATopKSAE(TopKSAE):
    """
    TopK Sparse Autoencoder equipped with a Relaxed Archetypal Dictionary.

    After the standard TopKSAE initialization, the decoder dictionary is
    swapped for a RelaxedArchetypalDictionary whose atoms are constrained to
    stay close to convex combinations of the provided data points.

    For more information, see:
    - "Archetypal SAE: Adaptive and Stable Dictionary Learning for Concept Extraction in
      Large Vision Models" by T. Fel et al., ICML 2025 (https://arxiv.org/abs/2502.12892).

    Parameters
    ----------
    input_shape : int
        Dimensionality of the input data (excluding the batch dimension).
    nb_concepts : int
        Number of dictionary atoms (concepts).
    points : torch.Tensor
        Data points used to initialize/define the archetypes,
        of shape (num_points, input_shape).
    top_k : int, optional
        Number of top activations kept in the latent representation.
        When None, the parent TopKSAE default (10% sparsity) is used.
    delta : float, optional
        Relaxation parameter of the archetypal dictionary, by default 1.0.
    use_multiplier : bool, optional
        Whether to learn a scalar multiplier that parametrizes the ball the
        atoms live on (e.g. a value of 3 places all atoms on the ball of
        radius 3). By default True.
    **kwargs : dict, optional
        Additional arguments forwarded to TopKSAE (e.g., encoder_module, device).
    """

    def __init__(self, input_shape, nb_concepts, points, top_k=None, delta=1.0, use_multiplier=True, **kwargs):
        assert isinstance(input_shape, int), "RATopKSAE input_shape must be an integer."

        super().__init__(input_shape=input_shape, nb_concepts=nb_concepts,
                         top_k=top_k, **kwargs)

        # the parent has built a default dictionary; replace it with the
        # relaxed archetypal one now that self.device is available
        archetypal_dictionary = RelaxedArchetypalDictionary(
            in_dimensions=input_shape,
            nb_concepts=nb_concepts,
            points=points,
            delta=delta,
            use_multiplier=use_multiplier,
            device=self.device,
        )
        self.dictionary = archetypal_dictionary


class RAJumpSAE(JumpSAE):
    """
    Jump Sparse Autoencoder equipped with a Relaxed Archetypal Dictionary.

    After the standard JumpSAE initialization, the decoder dictionary is
    swapped for a RelaxedArchetypalDictionary whose atoms are constrained to
    stay close to convex combinations of the provided data points.

    For more information, see:
    - "Archetypal SAE: Adaptive and Stable Dictionary Learning for Concept Extraction in
      Large Vision Models" by T. Fel et al., ICML 2025 (https://arxiv.org/abs/2502.12892).

    Parameters
    ----------
    input_shape : int
        Dimensionality of the input data (excluding the batch dimension).
    nb_concepts : int
        Number of dictionary atoms (concepts).
    points : torch.Tensor
        Data points used to initialize/define the archetypes,
        of shape (num_points, input_shape).
    bandwidth : float, optional
        Bandwidth of the JumpSAE kernel, by default 1e-3.
    delta : float, optional
        Relaxation parameter of the archetypal dictionary, by default 1.0.
    use_multiplier : bool, optional
        Whether to learn a scalar multiplier that parametrizes the ball the
        atoms live on (e.g. a value of 3 places all atoms on the ball of
        radius 3). By default True.
    **kwargs : dict, optional
        Additional arguments forwarded to JumpSAE (e.g., encoder_module, device).
    """

    def __init__(self, input_shape, nb_concepts, points, bandwidth=1e-3, delta=1.0, use_multiplier=True, **kwargs):
        assert isinstance(input_shape, int), "RAJumpSAE input_shape must be an integer."

        super().__init__(input_shape=input_shape, nb_concepts=nb_concepts,
                         bandwidth=bandwidth, **kwargs)

        # the parent has built a default dictionary; replace it with the
        # relaxed archetypal one now that self.device is available
        archetypal_dictionary = RelaxedArchetypalDictionary(
            in_dimensions=input_shape,
            nb_concepts=nb_concepts,
            points=points,
            delta=delta,
            use_multiplier=use_multiplier,
            device=self.device,
        )
        self.dictionary = archetypal_dictionary
32 changes: 24 additions & 8 deletions tests/sae/test_base_sae.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import pytest

import torch
from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE
from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE
from overcomplete.sae.modules import TieableEncoder

from ..utils import epsilon_equal

all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE]
all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE, RATopKSAE, RAJumpSAE]


def get_sae_kwargs(sae_class, input_size, nb_concepts, device):
    """Return the extra constructor kwargs that certain SAE classes require."""
    # only the relaxed archetypal variants need candidate points to build
    # their dictionary; every other class takes no extra argument
    if sae_class not in (RATopKSAE, RAJumpSAE):
        return {}
    return {'points': torch.randn(nb_concepts * 2, input_size, device=device)}


def test_dictionary_layer():
Expand All @@ -23,7 +32,9 @@ def test_dictionary_layer():
def test_sae(sae_class):
input_size = 10
nb_concepts = 5
model = sae_class(input_size, nb_concepts)

extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu')
model = sae_class(input_size, nb_concepts, **extra_kwargs)

x = torch.randn(3, input_size)
output = model(x)
Expand All @@ -43,13 +54,15 @@ def test_sae_device(sae_class):
input_size = 10
nb_components = 5

model = sae_class(input_size, nb_components, device='meta')
extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_components, device='meta')
model = sae_class(input_size, nb_components, device='meta', **extra_kwargs)

# ensure dictionary is on the meta device
dictionary = model.get_dictionary()
assert dictionary.device.type == 'meta'

model = sae_class(input_size, nb_components, device='cpu')
extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_components, device='cpu')
model = sae_class(input_size, nb_components, device='cpu', **extra_kwargs)

# ensure dictionary is on the cpu device
dictionary = model.get_dictionary()
Expand Down Expand Up @@ -111,7 +124,8 @@ def test_sae_tied_untied(sae_class):
input_size = 10
nb_concepts = 5

model = sae_class(input_size, nb_concepts)
extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu')
model = sae_class(input_size, nb_concepts, **extra_kwargs)

# Tie weights
model.tied()
Expand All @@ -130,7 +144,8 @@ def test_sae_tied_forward(sae_class):
input_size = 10
nb_concepts = 5

model = sae_class(input_size, nb_concepts)
extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu')
model = sae_class(input_size, nb_concepts, **extra_kwargs)
model.tied()

x = torch.randn(3, input_size)
Expand All @@ -146,7 +161,8 @@ def test_sae_untied_copy_weights(sae_class):
input_size = 10
nb_concepts = 5

model = sae_class(input_size, nb_concepts)
extra_kwargs = get_sae_kwargs(sae_class, input_size, nb_concepts, device='cpu')
model = sae_class(input_size, nb_concepts, **extra_kwargs)
model.tied()

# Get dictionary before untying
Expand Down
Loading