Agent skill
activation-patching
Causal intervention via activation patching to identify important model components. Use when determining which layers, heads, or positions are causally responsible for model behavior.
Install this agent skill to your Project
npx add-skill https://github.com/majiayu000/claude-skill-registry/tree/main/skills/data/activation-patching
Activation Patching
Activation patching is a causal intervention technique that identifies which model components are responsible for specific behaviors by swapping activations between different inputs.
Core Concept
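- Clean run: Run the model on a prompt that produces the desired behavior and cache its activations
- Corrupted run: Run the model on a minimally modified prompt that changes the behavior
- Patch: Rerun the corrupted prompt, replace one component's activations with the cached clean ones, and measure whether the behavior is restored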
If patching a component restores the clean behavior, that component is causally important.
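The three steps above can be illustrated without any model library. The following is a toy sketch, not the nnsight workflow: the `run` function, the component names `a` and `b`, and the weights are all hypothetical, chosen so that one component matters far more than the other.

```python
def run(x, patch=None):
    """Toy 'model' with two components; returns (output, activation cache).

    Hypothetical stand-in for a real network: component "a" reads x[0],
    component "b" reads x[1], and the output is their sum.
    """
    patch = patch or {}
    cache = {
        "a": patch.get("a", 2.0 * x[0]),  # use the patched value if one is given
        "b": patch.get("b", 0.1 * x[1]),
    }
    return cache["a"] + cache["b"], cache


clean_x, corrupted_x = (1.0, 1.0), (0.0, 0.0)

# Run 1: clean input, keep the activations
clean_out, clean_cache = run(clean_x)

# Run 2: corrupted baseline
corrupted_out, _ = run(corrupted_x)

# Runs 3+: patch one component at a time into the corrupted run
effects = {}
for name in ("a", "b"):
    patched_out, _ = run(corrupted_x, patch={name: clean_cache[name]})
    effects[name] = patched_out - corrupted_out  # how much output is restored

print(effects)
```

In this toy setting, patching `a` restores almost all of the clean output while patching `b` restores very little, which is the signal activation patching looks for in a real model.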
Basic Setup
from nnsight import LanguageModel
import torch
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
# Indirect Object Identification (IOI) task
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
# Target tokens
correct_token = model.tokenizer(" John")["input_ids"][0] # Clean answer
incorrect_token = model.tokenizer(" Mary")["input_ids"][0] # Corrupted answer
Metric: Logit Difference
def logit_diff(logits, correct_idx, incorrect_idx):
    """Measure how much the model prefers the correct token over the incorrect one."""
    # Compare logits at the final position of the first (only) sequence in the batch
    return (logits[0, -1, correct_idx] - logits[0, -1, incorrect_idx]).item()
Three-Run Patching Pattern
n_layers = len(model.transformer.h)
results = torch.zeros(n_layers)
# Run 1: Clean - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]
    clean_logits = model.lm_head.output.save()
# Run 2: Corrupted baseline
with model.trace(corrupted_prompt):
    corrupted_logits = model.lm_head.output.save()
# Runs 3+: Patch each layer (separate forward passes)
for layer_idx in range(n_layers):
    with model.trace(corrupted_prompt):
        # Replace corrupted activation with clean
        model.transformer.h[layer_idx].output[0][:] = clean_hiddens[layer_idx]
        patched_logits = model.lm_head.output.save()
    results[layer_idx] = logit_diff(patched_logits.value, correct_token, incorrect_token)
# Normalize results
clean_diff = logit_diff(clean_logits.value, correct_token, incorrect_token)
corrupted_diff = logit_diff(corrupted_logits.value, correct_token, incorrect_token)
normalized = (results - corrupted_diff) / (clean_diff - corrupted_diff)
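To make the normalization concrete, here is a small standalone sketch of the same formula applied to scalars. The helper name `normalize_effect`, the zero-denominator guard, and the numeric values are assumptions added for illustration; they are not measured results.

```python
def normalize_effect(patched, clean, corrupted, eps=1e-8):
    """Map a patched metric onto a scale where 0 = corrupted and 1 = clean."""
    denom = clean - corrupted
    if abs(denom) < eps:
        # Hypothetical guard: the two baselines must differ for the scale to exist
        raise ValueError("clean and corrupted baselines are too close to normalize")
    return (patched - corrupted) / denom


# Hypothetical logit differences, for illustration only
clean_example = 3.0       # clean run prefers the correct token
corrupted_example = -2.0  # corrupted run prefers the incorrect token
patched_example = 0.5     # a patched run lands in between

score = normalize_effect(patched_example, clean_example, corrupted_example)
print(score)
```

With these illustrative numbers the patched run recovers half of the gap between the corrupted and clean baselines.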
Position-Specific Patching
Patch only specific token positions:
seq_len = len(model.tokenizer.encode(clean_prompt))
results = torch.zeros(n_layers, seq_len)
# Clean run - save activations
with model.trace(clean_prompt):
    clean_hiddens = [layer.output[0].save() for layer in model.transformer.h]
# Patch each layer x position (separate forward passes)
for layer_idx in range(n_layers):
    for pos_idx in range(seq_len):
        with model.trace(corrupted_prompt):
            # Patch only this position
            model.transformer.h[layer_idx].output[0][:, pos_idx, :] = \
                clean_hiddens[layer_idx][:, pos_idx, :]
            patched_logits = model.lm_head.output.save()
        results[layer_idx, pos_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )
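Position-wise patching as written assumes the clean and corrupted prompts tokenize to the same length, so that position i means the same thing in both runs. A minimal sketch of a pre-flight check follows; `differing_positions` is a hypothetical helper and the token IDs are made up for illustration.

```python
def differing_positions(clean_ids, corrupted_ids):
    """Return the positions where two equal-length token sequences differ."""
    if len(clean_ids) != len(corrupted_ids):
        raise ValueError(
            f"length mismatch: {len(clean_ids)} clean vs {len(corrupted_ids)} corrupted tokens"
        )
    return [
        i
        for i, (clean_tok, corrupted_tok) in enumerate(zip(clean_ids, corrupted_ids))
        if clean_tok != corrupted_tok
    ]


# Made-up token IDs standing in for real tokenizer output
clean_ids = [10, 11, 12, 13, 14]
corrupted_ids = [10, 11, 99, 13, 14]

changed = differing_positions(clean_ids, corrupted_ids)
print(changed)
```

In practice the two lists would come from `model.tokenizer.encode(clean_prompt)` and `model.tokenizer.encode(corrupted_prompt)`; the returned positions show where the prompts actually diverge, which helps when reading the layer-by-position grid.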
Attention Head Patching
Patch individual attention heads:
n_heads = model.config.n_head
head_dim = model.config.n_embd // n_heads
results = torch.zeros(n_layers, n_heads)
# Clean run - save attention outputs (before projection)
with model.trace(clean_prompt):
    clean_attn = [layer.attn.c_proj.input[0][0].save()
                  for layer in model.transformer.h]
# Patch each layer x head (separate forward passes)
for layer_idx in range(n_layers):
    for head_idx in range(n_heads):
        with model.trace(corrupted_prompt):
            # Patch single head's output
            start = head_idx * head_dim
            end = (head_idx + 1) * head_dim
            model.transformer.h[layer_idx].attn.c_proj.input[0][0][:, :, start:end] = \
                clean_attn[layer_idx][:, :, start:end]
            patched_logits = model.lm_head.output.save()
        results[layer_idx, head_idx] = logit_diff(
            patched_logits.value, correct_token, incorrect_token
        )
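The head patching above works because the input to `c_proj` holds all heads concatenated along the last dimension, so each head owns a contiguous block of columns. The sketch below spells out that index arithmetic on its own; `head_slice` is a hypothetical helper, and the sizes are those of GPT-2 small (12 heads, hidden size 768).

```python
def head_slice(head_idx, n_heads, hidden_size):
    """Return the column range that one attention head occupies."""
    if hidden_size % n_heads != 0:
        raise ValueError("hidden_size must be divisible by n_heads")
    if not 0 <= head_idx < n_heads:
        raise IndexError(f"head_idx {head_idx} out of range for {n_heads} heads")
    head_dim = hidden_size // n_heads
    return slice(head_idx * head_dim, (head_idx + 1) * head_dim)


# GPT-2 small: 12 heads over a 768-dimensional hidden state, 64 columns per head
cols = head_slice(3, n_heads=12, hidden_size=768)
columns = list(range(768))[cols]
print(cols.start, cols.stop, len(columns))
```

Keeping the arithmetic in one place makes it harder to patch an off-by-one block of columns that straddles two heads.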
Noising (Reverse Patching)
Instead of restoring clean activations in a corrupted run, inject corrupted activations into the clean run and measure how much the behavior degrades:
# Corrupted run - save activations
with model.trace(corrupted_prompt):
    corrupted_hiddens = [layer.output[0].save() for layer in model.transformer.h]
# For each layer, inject corrupted activation into clean run
noising_results = torch.zeros(n_layers)
for layer_idx in range(n_layers):
    with model.trace(clean_prompt):
        # Inject corrupted activation into clean run
        model.transformer.h[layer_idx].output[0][:] = corrupted_hiddens[layer_idx]
        noised_logits = model.lm_head.output.save()
    noising_results[layer_idx] = logit_diff(noised_logits.value, correct_token, incorrect_token)
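Raw noising results are easier to compare across layers once they are put on a common scale. One possible convention, offered here as an assumption and not something the skill prescribes, is the fraction of the clean-to-corrupted gap that is lost when a component is noised. The helper name `noising_effect` and the numbers are hypothetical.

```python
def noising_effect(noised, clean, corrupted, eps=1e-8):
    """Fraction of the clean behavior lost: 0 = unaffected, 1 = fully corrupted."""
    denom = clean - corrupted
    if abs(denom) < eps:
        raise ValueError("clean and corrupted baselines are too close to normalize")
    return (clean - noised) / denom


# Hypothetical logit differences, for illustration only
clean_example = 3.0
corrupted_example = -2.0
noised_example = 1.0

lost = noising_effect(noised_example, clean_example, corrupted_example)
print(lost)
```

Under this convention, a component whose noising leaves the logit difference unchanged scores 0, and one whose noising reproduces the corrupted baseline scores 1.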
Visualization
import matplotlib.pyplot as plt
import seaborn as sns
# Assumes `results` is the (n_layers, seq_len) tensor from Position-Specific Patching;
# for head-level results, label the x-axis with head indices instead
plt.figure(figsize=(12, 8))
sns.heatmap(
    results.numpy(),
    xticklabels=[f"Pos {i}" for i in range(seq_len)],
    yticklabels=[f"Layer {i}" for i in range(n_layers)],
    cmap="RdBu_r",
    center=0,
    annot=True,
    fmt=".2f"
)
plt.title("Activation Patching Results")
plt.xlabel("Token Position")
plt.ylabel("Layer")
plt.tight_layout()
plt.show()
Interpretation
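These readings assume scores normalized as in the Three-Run Patching Pattern (0 = corrupted baseline, 1 = clean run); raw logit differences should be compared against corrupted_diff, not 0.
- High positive values (near 1): Patching this component alone restores most of the clean behavior, so it is important for the correct answer
- Values near 0: Patching this component has little effect on this behavior
- Negative values: Patching this component pushes the output further toward the wrong answer
- Clusters of importance: Suggest circuits or computational stages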
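When the grid is large, it can help to list the strongest components before reading the heatmap. The sketch below ranks cells of a score grid by absolute magnitude; `top_components` is a hypothetical helper and the score values are invented for illustration, not patching results.

```python
def top_components(scores, k=3):
    """Return the k (row, col, value) cells with the largest absolute score."""
    flat = [
        (row, col, value)
        for row, row_values in enumerate(scores)
        for col, value in enumerate(row_values)
    ]
    # Sort by magnitude so strongly negative components are surfaced too
    flat.sort(key=lambda cell: abs(cell[2]), reverse=True)
    return flat[:k]


# Invented normalized scores: rows are layers, columns are positions or heads
example_scores = [
    [0.02, 0.10, -0.05],
    [0.01, 0.85, 0.00],
    [-0.30, 0.40, 0.03],
]

ranked = top_components(example_scores, k=3)
for row, col, value in ranked:
    print(f"layer {row}, column {col}: {value:+.2f}")
```

A tensor of results can be passed in as nested lists via `results.tolist()`. Ranking by absolute value is a design choice: it keeps components that push toward the wrong answer in view alongside those that restore the correct one.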