Agent skill

attribution-patching

Gradient-based approximation to activation patching for scalable circuit analysis. Use when activation patching is too slow or when analyzing many components simultaneously.

Stars 163
Forks 31

Install this agent skill to your Project

npx add-skill https://github.com/majiayu000/claude-skill-registry/tree/main/skills/data/attribution-patching

SKILL.md

Attribution Patching

Attribution patching uses gradients to approximate activation patching results in a single backward pass, making it practical to analyze thousands of components simultaneously.

Core Idea

Instead of running separate forward passes for each component:

  1. Run clean and corrupted forward passes
  2. Compute gradients of the metric w.r.t. corrupted activations
  3. Multiply gradients by (clean - corrupted) activation differences

This linear approximation works when clean and corrupted runs are similar.

Mathematical Formula

attribution(component) = grad_corrupted(metric) * (clean_activation - corrupted_activation)

Setup

python
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

correct_token = model.tokenizer(" John")["input_ids"][0]
incorrect_token = model.tokenizer(" Mary")["input_ids"][0]

def logit_diff(logits):
    return logits[0, -1, correct_token] - logits[0, -1, incorrect_token]

Basic Attribution Patching

python
n_layers = len(model.transformer.h)

clean_acts = []
corrupted_acts = []
corrupted_grads = []

# Clean forward pass - save activations
with model.trace(clean_prompt):
    for layer in model.transformer.h:
        act = layer.output[0]
        clean_acts.append(act.save())

# Corrupted forward + backward pass
with model.trace(corrupted_prompt):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        act = layer.output[0]
        act.requires_grad = True
        corrupted_acts.append(act.save())

    # Compute metric
    logits = model.lm_head.output
    metric = logit_diff(logits)

    # Access gradients in REVERSE order within backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            corrupted_grads.insert(0, layer.output[0].grad.save())

# Compute attributions
attributions = []
for i in range(n_layers):
    clean = clean_acts[i].value
    corrupted = corrupted_acts[i].value
    grad = corrupted_grads[i].value

    # Attribution = grad * (clean - corrupted)
    attr = (grad * (clean - corrupted)).sum()
    attributions.append(attr.item())

attributions = torch.tensor(attributions)

Per-Position Attribution

python
seq_len = clean_acts[0].value.shape[1]
position_attrs = torch.zeros(n_layers, seq_len)

for layer_idx in range(n_layers):
    clean = clean_acts[layer_idx].value
    corrupted = corrupted_acts[layer_idx].value
    grad = corrupted_grads[layer_idx].value

    # Sum over hidden dimension only, keep position
    diff = clean - corrupted
    attr = (grad * diff).sum(dim=-1).squeeze()  # [seq_len]
    position_attrs[layer_idx] = attr

Attention Head Attribution

python
from einops import rearrange

n_heads = model.config.n_head
head_dim = model.config.n_embd // n_heads
head_attrs = torch.zeros(n_layers, n_heads)

# Collect clean attention outputs
clean_attn = []
with model.trace(clean_prompt):
    for layer in model.transformer.h:
        attn_out = layer.attn.c_proj.input[0][0]  # Before projection
        clean_attn.append(attn_out.save())

# Collect corrupted attention outputs and gradients
corrupted_attn = []
attn_grads = []

with model.trace(corrupted_prompt):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        attn_out = layer.attn.c_proj.input[0][0]
        attn_out.requires_grad = True
        corrupted_attn.append(attn_out.save())

    metric = logit_diff(model.lm_head.output)

    # Access gradients in REVERSE order within backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            attn_grads.insert(0, layer.attn.c_proj.input[0][0].grad.save())

# Compute per-head attributions
for layer_idx in range(n_layers):
    clean = clean_attn[layer_idx].value
    corrupted = corrupted_attn[layer_idx].value
    grad = attn_grads[layer_idx].value

    # Reshape to [batch, seq, heads, head_dim]
    clean_heads = rearrange(clean, 'b s (h d) -> b s h d', h=n_heads)
    corrupted_heads = rearrange(corrupted, 'b s (h d) -> b s h d', h=n_heads)
    grad_heads = rearrange(grad, 'b s (h d) -> b s h d', h=n_heads)

    # Attribution per head
    diff = clean_heads - corrupted_heads
    attr = (grad_heads * diff).sum(dim=(0, 1, 3))  # Sum batch, seq, head_dim
    head_attrs[layer_idx] = attr

Efficient Batched Version

Process both prompts in a single forward pass using batching:

python
# Batch both prompts together in a single trace
all_acts = []
all_grads = []

with model.trace([clean_prompt, corrupted_prompt]):
    # Register intermediate values in forward order
    for layer in model.transformer.h:
        act = layer.output[0]
        act.requires_grad = True
        all_acts.append(act.save())

    logits = model.lm_head.output
    # Metric on corrupted (index 1)
    metric = logit_diff(logits[1:2])

    # Access gradients in REVERSE order within backward context
    with metric.backward():
        for layer in reversed(model.transformer.h):
            all_grads.insert(0, layer.output[0].grad.save())

# Split clean/corrupted and compute attributions
attributions = []
for i in range(n_layers):
    acts = all_acts[i].value
    grads = all_grads[i].value

    clean = acts[0:1]
    corrupted = acts[1:2]
    grad = grads[1:2]  # Gradient is only for corrupted

    attr = (grad * (clean - corrupted)).sum()
    attributions.append(attr.item())

Comparison with Activation Patching

Aspect Activation Patching Attribution Patching
Accuracy Exact Approximation
Speed O(n_components) forwards O(1) forward + backward
Memory Lower per run Higher (stores grads)
Best for Few components Many components

Validation

Compare attribution results against ground truth patching:

python
# Scatter plot: attribution vs actual patching effect
import matplotlib.pyplot as plt

plt.scatter(attributions, actual_patching_results)
plt.xlabel("Attribution Score")
plt.ylabel("Actual Patching Effect")
plt.title("Attribution vs Patching Correlation")
correlation = torch.corrcoef(torch.stack([attributions, actual_patching_results]))[0, 1]
plt.text(0.1, 0.9, f"r = {correlation:.3f}", transform=plt.gca().transAxes)

When to Use

  • Use attribution patching: Initial exploration, many components, large models
  • Use activation patching: Validating specific components, exact measurements needed
  • Combine both: Attribution for screening, patching for confirmation

Expand your agent's capabilities with these related and highly-rated skills.

Didn't find tool you were looking for?

Be as detailed as possible for better results