Agent skill

activation-patching

Causal intervention via activation patching to identify important model components. Use when determining which layers, heads, or positions are causally responsible for model behavior.

Install this agent skill to your Project

npx add-skill https://github.com/majiayu000/claude-skill-registry/tree/main/skills/data/activation-patching

SKILL.md

Activation Patching

Activation patching is a causal intervention technique that identifies which model components are responsible for specific behaviors by swapping activations between different inputs.

Core Concept

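  1. Clean run: Run the model on a prompt that produces the desired behavior, caching its activations
  2. Corrupted run: Run the model on a minimally modified prompt that changes the behavior
  3. Patch: Rerun the corrupted prompt with one component's activation replaced by its clean value, and measure how much of the clean behavior is restored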

If patching a component restores the clean behavior, that component is causally important.
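
The three runs can be illustrated without a real network. The sketch below uses a hypothetical two-component toy model (the names `component_a`, `component_b`, and `run` are invented for illustration and belong to no library). It caches the clean activations, patches each one into the corrupted run in turn, and reports the fraction of the clean output that is recovered.

```python
# Toy "model": two components whose activations are summed into one output.
def component_a(x):
    return 2.0 * x[0]   # depends on the feature that differs between prompts

def component_b(x):
    return 1.0          # constant; carries no information about the input

def run(x, patch=None):
    """Run the toy model. `patch` maps a component name to a replacement activation."""
    patch = patch or {}
    acts = {
        "a": patch.get("a", component_a(x)),
        "b": patch.get("b", component_b(x)),
    }
    return sum(acts.values()), acts

clean_x, corrupted_x = (1.0, 5.0), (-1.0, 5.0)

# Run 1: clean, caching activations. Run 2: corrupted baseline.
clean_out, clean_acts = run(clean_x)
corrupted_out, _ = run(corrupted_x)

# Runs 3+: patch one component at a time into the corrupted run.
restored = {}
for name, clean_act in clean_acts.items():
    patched_out, _ = run(corrupted_x, patch={name: clean_act})
    restored[name] = (patched_out - corrupted_out) / (clean_out - corrupted_out)

print(restored)  # {'a': 1.0, 'b': 0.0}
```

Component `a` fully restores the clean output, while patching `b` changes nothing, which is the pattern the real experiments below look for across layers, positions, and heads.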

Basic Setup

python
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

# Indirect Object Identification (IOI) task
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

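# Target tokens (assumes " John" and " Mary", with the leading space, each encode to a single token)
correct_token = model.tokenizer(" John")["input_ids"][0]   # Answer on the clean prompt
incorrect_token = model.tokenizer(" Mary")["input_ids"][0]  # Answer on the corrupted prompt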

Metric: Logit Difference

python
def logit_diff(logits, correct_idx, incorrect_idx):
    """Measure how much model prefers correct over incorrect token."""
    return (logits[0, -1, correct_idx] - logits[0, -1, incorrect_idx]).item()
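
One reason logit difference is a convenient metric is that the softmax normalizer cancels: the difference of two logits equals the log of the ratio of their probabilities. The check below is a small pure-Python sketch using made-up logits over a three-token vocabulary; the values and the `softmax` helper are illustrative assumptions, not model output.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical final-position logits over a 3-token vocabulary.
logits = [2.0, 0.5, -1.0]
correct_idx, incorrect_idx = 0, 1

diff = logits[correct_idx] - logits[incorrect_idx]
probs = softmax(logits)
log_ratio = math.log(probs[correct_idx] / probs[incorrect_idx])

# The two quantities agree up to floating-point error.
print(diff, round(log_ratio, 6))
```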

Three-Run Patching Pattern

python
n_layers = len(model.transformer.h)
results = torch.zeros(n_layers)

# Run 1: Clean - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]
    clean_logits = model.lm_head.output.save()

# Run 2: Corrupted baseline
with model.trace(corrupted_prompt):
    corrupted_logits = model.lm_head.output.save()

# Runs 3+: Patch each layer (separate forward passes)
for layer_idx in range(n_layers):
    with model.trace(corrupted_prompt):
        # Replace corrupted activation with clean
        model.transformer.h[layer_idx].output[0][:] = clean_hiddens[layer_idx]
        patched_logits = model.lm_head.output.save()
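    # Note: `.value` reads the saved tensor on older nnsight releases; newer releases may
    # expose saved values directly, in which case drop `.value`.
    results[layer_idx] = logit_diff(patched_logits.value, correct_token, incorrect_token)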

# Normalize results
clean_diff = logit_diff(clean_logits.value, correct_token, incorrect_token)
corrupted_diff = logit_diff(corrupted_logits.value, correct_token, incorrect_token)
normalized = (results - corrupted_diff) / (clean_diff - corrupted_diff)
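
The normalization maps each patched result onto a scale where 0 means "no better than the corrupted run" and 1 means "fully restored to the clean run". A worked example with hypothetical numbers (the helper name `normalize_score` and all values are assumptions for illustration):

```python
def normalize_score(patched_diff, clean_diff, corrupted_diff):
    """Map a patched logit diff onto a scale where 0 = corrupted and 1 = clean."""
    span = clean_diff - corrupted_diff
    if span == 0:
        raise ValueError("clean and corrupted runs give the same logit diff; nothing to recover")
    return (patched_diff - corrupted_diff) / span

# Hypothetical baselines and per-layer patched results.
clean_diff, corrupted_diff = 3.0, -2.0
patched_diffs = [-2.0, 0.5, 3.0]

scores = [normalize_score(p, clean_diff, corrupted_diff) for p in patched_diffs]
print(scores)  # [0.0, 0.5, 1.0]
```

The guard matters in practice: if the corrupted prompt does not actually change the model's preference, the denominator is close to zero and the scores are meaningless.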

Position-Specific Patching

Patch only specific token positions:

python
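# Assumes both prompts tokenize to the same length, so position i lines up across runs
seq_len = len(model.tokenizer.encode(clean_prompt))
results = torch.zeros(n_layers, seq_len)  # raw logit diffs; normalize as above before interpreting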

# Clean run - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# Patch each layer x position (separate forward passes)
for layer_idx in range(n_layers):
    for pos_idx in range(seq_len):
        with model.trace(corrupted_prompt):
            # Patch only this position
            model.transformer.h[layer_idx].output[0][:, pos_idx, :] = \
                clean_hiddens[layer_idx][:, pos_idx, :]
            patched_logits = model.lm_head.output.save()
        results[layer_idx, pos_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )
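
Position-wise patching only makes sense when position i refers to the same slot in both prompts, so it is worth checking alignment before running the sweep. The sketch below uses whitespace splitting as a stand-in for the real tokenizer, and `differing_positions` is a hypothetical helper; in practice one would pass the token ID lists from `model.tokenizer.encode(...)`, whose indices will differ from the whitespace ones shown here.

```python
def differing_positions(clean_tokens, corrupted_tokens):
    """Return the indices where two equal-length token sequences differ."""
    if len(clean_tokens) != len(corrupted_tokens):
        raise ValueError(
            f"length mismatch: {len(clean_tokens)} vs {len(corrupted_tokens)} tokens"
        )
    return [
        i for i, (c, k) in enumerate(zip(clean_tokens, corrupted_tokens)) if c != k
    ]

clean_text = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_text = "After John and Mary went to the store, John gave a bottle of milk to"

# Whitespace split is only a stand-in for the model's tokenizer.
clean_tokens = clean_text.split()
corrupted_tokens = corrupted_text.split()

positions = differing_positions(clean_tokens, corrupted_tokens)
print(len(clean_tokens), positions)  # 15 [8]
```

Under this stand-in tokenization the prompts differ only at the second subject mention, which is where early-layer patching would be expected to have the most direct effect.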

Attention Head Patching

Patch individual attention heads:

python
n_heads = model.config.n_head
head_dim = model.config.n_embd // n_heads
# Separate name so the layer x position `results` grid above is not overwritten
head_results = torch.zeros(n_layers, n_heads)

# Clean run - save attention outputs (concatenated per-head outputs, before projection)
# Note: `.input[0][0]` selects the first positional argument on older nnsight releases;
# newer releases may expose that tensor as `.input` directly.
with model.trace(clean_prompt):
    clean_attn = [layer.attn.c_proj.input[0][0].save()
                  for layer in model.transformer.h]

# Patch each layer x head (separate forward passes)
for layer_idx in range(n_layers):
    for head_idx in range(n_heads):
        with model.trace(corrupted_prompt):
            # Patch a single head's slice of the concatenated output
            start = head_idx * head_dim
            end = (head_idx + 1) * head_dim
            model.transformer.h[layer_idx].attn.c_proj.input[0][0][:, :, start:end] = \
                clean_attn[layer_idx][:, :, start:end]
            patched_logits = model.lm_head.output.save()
        head_results[layer_idx, head_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )
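
The slicing arithmetic above relies on the per-head outputs being concatenated along the last dimension in head order. A small standalone check of that index math, using GPT-2 small's dimensions (12 heads, 768-wide embeddings); `head_slice` is a hypothetical helper introduced only for this illustration.

```python
def head_slice(head_idx, head_dim):
    """Slice of the concatenated attention output that belongs to one head."""
    return slice(head_idx * head_dim, (head_idx + 1) * head_dim)

n_embd, n_heads = 768, 12          # GPT-2 small
head_dim = n_embd // n_heads       # 64

slices = [head_slice(h, head_dim) for h in range(n_heads)]

# Together the head slices should tile the embedding dimension exactly once.
covered = [i for s in slices for i in range(s.start, s.stop)]

print(head_dim, slices[3].start, slices[3].stop)  # 64 192 256
```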

Noising (Reverse Patching)

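Instead of restoring clean activations in the corrupted run, inject corrupted activations into the clean run. A large drop in logit difference indicates that the component is necessary for the behavior: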

python
# Corrupted run - save activations
with model.trace(corrupted_prompt):
    corrupted_hiddens = [layer.output[0].save() for layer in model.transformer.h]

# For each layer, inject corrupted activation into clean run
noising_results = torch.zeros(n_layers)
for layer_idx in range(n_layers):
    with model.trace(clean_prompt):
        # Inject corrupted activation into clean run
        model.transformer.h[layer_idx].output[0][:] = corrupted_hiddens[layer_idx]
        noised_logits = model.lm_head.output.save()
    noising_results[layer_idx] = logit_diff(noised_logits.value, correct_token, incorrect_token)
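
Denoising and noising answer different questions: denoising asks whether a component is sufficient to restore the behavior, while noising asks whether it is necessary. The toy below shows how the two can disagree. It is a hypothetical model with two redundant components combined by `max`, not a real network, and every name in it is invented for illustration.

```python
def run(x, patch=None):
    """Toy model with two redundant components; either one can carry the signal."""
    patch = patch or {}
    a = patch.get("a", x)   # component a copies the input
    b = patch.get("b", x)   # component b is a redundant copy
    return max(a, b)

clean_x, corrupted_x = 1.0, 0.0
clean_out = run(clean_x)
corrupted_out = run(corrupted_x)

# Denoising: put a's clean activation into the corrupted run.
denoise_a = run(corrupted_x, patch={"a": clean_x})

# Noising: put a's corrupted activation into the clean run.
noise_a = run(clean_x, patch={"a": corrupted_x})

print(denoise_a == clean_out)  # True: a alone is sufficient to restore the output
print(noise_a == clean_out)    # True: a is not necessary, because b compensates
```

When components back each other up like this, noising can report little effect for a component that denoising flags as important, so running both directions gives a fuller picture.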

Visualization

python
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(12, 8))
sns.heatmap(
    results.numpy(),  # expects the (n_layers, seq_len) grid from position-specific patching
    xticklabels=[f"Pos {i}" for i in range(seq_len)],
    yticklabels=[f"Layer {i}" for i in range(n_layers)],
    cmap="RdBu_r",
    center=0,
    annot=True,
    fmt=".2f"
)
plt.title("Activation Patching Results")
plt.xlabel("Token Position")
plt.ylabel("Layer")
plt.tight_layout()
plt.show()

Interpretation

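These readings apply to normalized scores (0 = corrupted baseline, 1 = clean baseline):

  • High positive values (near 1): Patching the component restores most of the clean behavior
  • Values near 0: Patching the component leaves the corrupted behavior unchanged
  • Negative values: Patching the component pushes the output further toward the wrong answer
  • Clusters of importance: Suggest circuits or computational stages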
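
To move from a heatmap to a shortlist, it can help to rank components by the magnitude of their normalized score. A minimal sketch with a made-up 3x4 grid of scores follows; `top_components` is a hypothetical helper, and with real results one could pass a nested list such as a tensor's `.tolist()` output.

```python
def top_components(scores, k=3):
    """Return the k (layer, position, score) entries with the largest |score|."""
    flat = [
        (score, layer, pos)
        for layer, row in enumerate(scores)
        for pos, score in enumerate(row)
    ]
    flat.sort(key=lambda item: abs(item[0]), reverse=True)
    return [(layer, pos, score) for score, layer, pos in flat[:k]]

# Made-up normalized scores: rows are layers, columns are token positions.
scores = [
    [0.02, 0.00, 0.01, 0.05],
    [0.10, 0.65, 0.03, 0.20],
    [0.01, -0.30, 0.04, 0.90],
]

top = top_components(scores, k=3)
print(top)  # [(2, 3, 0.9), (1, 1, 0.65), (2, 1, -0.3)]
```

Ranking by absolute value keeps strongly negative entries in view, since those mark components whose clean activation pushes toward the wrong answer.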
