Agent skill

logit-lens

Decode intermediate layer predictions using the Logit Lens technique. Use when analyzing what a model predicts at each layer, understanding information flow, or visualizing layer-wise processing.

Install this agent skill to your Project

npx add-skill https://github.com/majiayu000/claude-skill-registry/tree/main/skills/data/logit-lens

SKILL.md

Logit Lens

Logit Lens decodes intermediate layer activations into vocabulary predictions, revealing what the model "thinks" at each processing step rather than just the final output.

Concept

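Transformer language models refine their prediction layer by layer: each block adds its output to a shared residual stream. Because every layer reads from and writes to that same stream, the final layer norm and unembedding head can be applied to any intermediate hidden state, which shows how the predicted next-token distribution evolves with depth.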
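
The projection described above can be illustrated without loading a model. The following is a minimal, dependency-free sketch on toy numbers. The hidden states, the unembedding matrix `W_U`, and the helper names (`layer_norm`, `softmax`, `logit_lens`) are all made up for illustration, and this layer norm omits the learned scale and bias that a real `ln_f` carries.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logit_lens(hidden, unembed):
    """Project one hidden state onto the vocabulary: softmax(W_U @ LayerNorm(h))."""
    normed = layer_norm(hidden)
    logits = [sum(w * h for w, h in zip(row, normed)) for row in unembed]
    return softmax(logits)

# Toy setup (hypothetical values): hidden size 3, vocabulary of 2 tokens.
W_U = [[1.0, 0.0, -1.0],    # unembedding row for token 0
       [-1.0, 0.0, 1.0]]    # unembedding row for token 1
hidden_early = [0.1, 0.0, 0.2]   # stand-in for an intermediate-layer state
hidden_late = [2.0, 0.0, -2.0]   # stand-in for a later-layer state

probs_early = logit_lens(hidden_early, W_U)
probs_late = logit_lens(hidden_late, W_U)
print("early:", probs_early)
print("late: ", probs_late)
```

In this toy case the early state favors token 1 and the late state favors token 0, which is the kind of shift in prediction that the logit lens is meant to expose.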

Basic Implementation

python
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

prompt = "The Eiffel Tower is in the city of"
layers = model.transformer.h
probs_layers = []

with model.trace(prompt):
    for layer_idx, layer in enumerate(layers):
        # Get layer output, apply final layer norm, then lm_head
        hidden = layer.output[0]
        normed = model.transformer.ln_f(hidden)
        logits = model.lm_head(normed)

        # Convert to probabilities
        probs = torch.nn.functional.softmax(logits, dim=-1).save()
        probs_layers.append(probs)

Extract Top Predictions

python
# Stack all layer probabilities
all_probs = torch.stack([p.value for p in probs_layers])  # [n_layers, batch, seq, vocab]

# Get top prediction at each layer for final token
final_token_probs = all_probs[:, 0, -1, :]  # [n_layers, vocab]
top_probs, top_tokens = final_token_probs.max(dim=-1)

# Decode predictions
for layer_idx, (prob, token) in enumerate(zip(top_probs, top_tokens)):
    word = model.tokenizer.decode(token.item())
    print(f"Layer {layer_idx}: '{word}' (prob: {prob:.3f})")
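
Top-1 predictions hide how close the correct token is to the top. A common complement is to track the rank and probability of a chosen target token at each layer. The sketch below works on plain Python lists so that it runs without a model. `target_rank_by_layer` is a hypothetical helper and the toy distributions are made up. With the tensors above, one could presumably pass `final_token_probs.tolist()` together with a token id from `model.tokenizer`, but that wiring is an assumption and is not tested here.

```python
def target_rank_by_layer(layer_probs, target_id):
    """Return (rank, probability) of target_id per layer; rank 0 is the top prediction."""
    results = []
    for probs in layer_probs:
        p_target = probs[target_id]
        # Rank = number of tokens with strictly higher probability.
        rank = sum(1 for p in probs if p > p_target)
        results.append((rank, p_target))
    return results

# Toy data (made-up numbers): 3 layers over a 4-token vocabulary.
toy_probs = [
    [0.40, 0.30, 0.20, 0.10],
    [0.30, 0.25, 0.35, 0.10],
    [0.05, 0.05, 0.85, 0.05],
]
ranks = target_rank_by_layer(toy_probs, target_id=2)
for layer_idx, (rank, prob) in enumerate(ranks):
    print(f"Layer {layer_idx}: rank {rank}, prob {prob:.2f}")
```

A falling rank with a still-low probability suggests the target is gaining ground before it becomes the top prediction.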

Full Sequence Visualization

python
# Get the top prediction at every position (reuses all_probs from above)
max_probs, tokens = all_probs[:, 0, :, :].max(dim=-1)  # each [n_layers, seq_len]

# Decode to words
words = [[model.tokenizer.decode(t.item()) for t in layer_tokens]
         for layer_tokens in tokens]

# Create visualization matrix
input_tokens = model.tokenizer.encode(prompt)
input_words = [model.tokenizer.decode(t) for t in input_tokens]

print("Position:", input_words)
for layer_idx, layer_words in enumerate(words):
    print(f"Layer {layer_idx:2d}:", layer_words)

Efficient Batched Version

To analyze several prompts together, or to compare behavior across prompts, give each prompt its own invoke inside a single trace:

python
prompts = [
    "The capital of France is",
    "The capital of Germany is",
    "The capital of Japan is"
]

all_results = []

with model.trace() as tracer:
    for prompt in prompts:
        with tracer.invoke(prompt):
            prompt_probs = []
            for layer in model.transformer.h:
                hidden = layer.output[0]
                logits = model.lm_head(model.transformer.ln_f(hidden))
                # Position -1 is the last real token only if prompts are left-padded or equal in length
                probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1).save()
                prompt_probs.append(probs)
            all_results.append(prompt_probs)

Remote Execution for Large Models

python
from nnsight import CONFIG
CONFIG.set_default_api_key("YOUR_API_KEY")

model = LanguageModel("meta-llama/Llama-3.1-70B")

with model.trace("The meaning of life is", remote=True):
    layer_probs = []
    for layer in model.model.layers:
        hidden = layer.output[0]
        normed = model.model.norm(hidden)
        logits = model.lm_head(normed)
        probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1).save()
        layer_probs.append(probs)

Interpretation Tips

  • Early layers: Often predict generic/common tokens
  • Middle layers: Begin forming task-relevant predictions
  • Late layers: Converge to final prediction
  • Sudden changes: May indicate important computation happening at that layer
  • Persistent wrong predictions: Suggest that the relevant information has not yet been integrated
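
The tips above can be turned into two simple checks: where the top prediction changes between layers, and from which layer on it matches the final prediction. This is a minimal sketch with made-up tokens. `prediction_changes` and `convergence_layer` are hypothetical helpers, not part of nnsight. With real data one might feed them `top_tokens.tolist()` from the "Extract Top Predictions" step.

```python
def prediction_changes(top_tokens):
    """Layers at which the top-1 token differs from the previous layer's."""
    return [i for i in range(1, len(top_tokens))
            if top_tokens[i] != top_tokens[i - 1]]

def convergence_layer(top_tokens):
    """Earliest layer from which the top-1 token equals the final one through to the end."""
    final = top_tokens[-1]
    layer = len(top_tokens) - 1
    while layer > 0 and top_tokens[layer - 1] == final:
        layer -= 1
    return layer

# Made-up per-layer top-1 tokens for the final position.
toy_top = [" the", " the", " a", " Paris", " London", " Paris", " Paris", " Paris"]

changes = prediction_changes(toy_top)
converged_at = convergence_layer(toy_top)
print("Top-1 changes at layers:", changes)
print("Converged from layer:", converged_at)
```

Note that the toy sequence reaches " Paris" at layer 3 but only settles on it from layer 5, so "first match" and "converged" are deliberately treated as different events.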

Visualization with Plotly

python
import plotly.graph_objects as go

fig = go.Figure(data=go.Heatmap(
    z=max_probs.detach().cpu().float().numpy(),  # detach and move to CPU before converting
    x=input_words,
    y=[f"Layer {i}" for i in range(len(layers))],
    colorscale="Blues",
    text=words,
    texttemplate="%{text}",
    textfont={"size": 10},
))

fig.update_layout(
    title="Logit Lens: Layer-wise Predictions",
    xaxis_title="Input Position",
    yaxis_title="Layer"
)
fig.show()

Use Cases

  1. Debugging model behavior: See where predictions go wrong
  2. Understanding factual recall: When does the model "know" the answer?
  3. Comparing model architectures: Different models show different patterns
  4. Identifying critical layers: Which layers matter most for a task?
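
For the last use case, one way to quantify which layers matter is to measure how far each layer's distribution is from the final one and look for the largest drop. The sketch below uses KL divergence on plain Python lists. The direction KL(final || layer), the helper names, and the toy distributions are illustrative choices, not something the technique prescribes.

```python
import math

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) in nats for two discrete distributions given as lists."""
    return sum(pi * math.log((pi + eps) / (qi + eps))
               for pi, qi in zip(p, q) if pi > 0)

def kl_to_final(layer_probs):
    """KL(final || layer) for every layer; the last entry is 0 by construction."""
    final = layer_probs[-1]
    return [kl_divergence(final, probs) for probs in layer_probs]

# Toy data (made-up numbers): 3 layers over a 4-token vocabulary.
toy_layers = [
    [0.25, 0.25, 0.25, 0.25],
    [0.10, 0.10, 0.70, 0.10],
    [0.05, 0.05, 0.85, 0.05],
]
kls = kl_to_final(toy_layers)

# Layer whose update brings the distribution closest to the final one.
largest_drop = max(range(1, len(kls)), key=lambda i: kls[i - 1] - kls[i])

for layer_idx, kl in enumerate(kls):
    print(f"Layer {layer_idx}: KL to final = {kl:.4f}")
print("Largest drop occurs entering layer", largest_drop)
```

A large drop at one layer is only a hint that the layer is important for the prompt at hand. Confirming it would need an intervention such as ablation or patching.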
