🔍

The Problem

A trained model is a black box — unless you pry it open

You've trained an ACT policy. It picks up cubes, stacks them, maybe even does it reliably. But what is it actually looking at? When the robot arm moves left, is the model looking at the cube? At the gripper? At a shadow on the table?

This matters more than curiosity. If your model is succeeding by looking at the right things, it will generalize. If it's succeeding by coincidence — a shadow that happens to correlate with the cube position in your training data — it will break the moment something changes.

We have two tools to answer this question, and they look at different parts of the ACT architecture:

GradCAM looks at the ResNet backbone (the CNN). It answers: "which pixels in the image caused the strongest feature activations that influenced the action output?"

Attention visualization looks at the transformer decoder. It answers: "when the decoder was deciding what action to take, which spatial positions in the encoder output did it attend to?"

GradCAM and attention maps answer different questions about the same model. GradCAM reveals which CNN features the action output is sensitive to. Attention reveals which encoder tokens the decoder's queries read from. The two can disagree: a feature that influences the action through a path the decoder's cross-attention weights don't capture will show up in GradCAM but not in attention.

[Figure: Two visualization modes target different parts of the ACT architecture]
🔄

The ACT Pipeline

Quick recap: image → features → tokens → actions

Before we can understand the visualization, we need to trace the exact path data takes through the model. Let's follow one camera image from raw pixels to action prediction:

Step 1 — ResNet backbone. Each camera image (say 480×640 RGB) goes through a ResNet-18 CNN. The last convolutional block (layer4) outputs a feature map: a 3D tensor of shape [512, 15, 20]. Think of it as 512 "detector channels" over a 15×20 spatial grid. Each cell in this grid roughly corresponds to a 32×32 patch of the original image.

Step 2 — Flatten to tokens. The feature map is reshaped into 300 tokens (15×20 = 300), each of dimension 512. Positional embeddings are added so the transformer knows where each token came from spatially.

Step 3 — Transformer encoder. The 300 image tokens (per camera), plus the robot state token and style variable, are fed through the transformer encoder's self-attention layers. The output: contextual embeddings.

Step 4 — Transformer decoder. Learned query embeddings (one per action timestep) attend to the encoder output via cross-attention. Each query "asks" the encoder output for the information it needs to predict one action in the chunk.

[Figure: The ACT data flow: image pixels → CNN feature map → tokens → encoder → decoder → actions]
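To make the shapes in Steps 1 and 2 concrete, here is a minimal NumPy sketch of the arithmetic. It assumes a total backbone stride of 32 and a row-major flatten; the helper names (`feature_grid`, `token_to_cell`, `cell_to_pixels`) are made up for illustration and are not part of ACT. It leaves out the positional embeddings and uses a dummy array instead of real CNN features.

```python
import numpy as np

def feature_grid(height, width, stride=32):
    """Spatial size of the final feature map for a backbone with this total stride."""
    # Ceil division; for 480x640 (exact multiples of 32) it equals plain division.
    return -(-height // stride), -(-width // stride)

def token_to_cell(index, grid_width):
    """Map a flattened token index back to its (row, col) grid cell."""
    return divmod(index, grid_width)

def cell_to_pixels(row, col, stride=32):
    """Approximate (top, left, bottom, right) image patch covered by a grid cell."""
    return (row * stride, col * stride, (row + 1) * stride, (col + 1) * stride)

H, W = feature_grid(480, 640)     # 15 rows, 20 columns
C = 512

# Dummy feature map standing in for the layer4 output of one camera.
feature_map = np.arange(C * H * W, dtype=np.float32).reshape(C, H, W)

# [C, H, W] -> [H*W, C]: one 512-dim token per grid cell, in row-major order.
tokens = feature_map.reshape(C, H * W).T

row, col = token_to_cell(110, W)  # token 110 sits at row 5, col 10
patch = cell_to_pixels(row, col)
```

Keeping the token-to-cell mapping explicit matters later: it is the same mapping used to fold attention weights back into a 2D map.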

Our two visualization tools hook into different stages of this pipeline. GradCAM hooks into Step 1 (the CNN output). Attention visualization hooks into Step 4 (the decoder's cross-attention). Let's look at each one.

Activations

What the CNN detects at each spatial position

The ResNet backbone processes your camera image through layers of convolutions. At each layer, it builds increasingly abstract representations. layer1 detects edges and corners. layer2 detects textures and simple shapes. By the time we reach layer4 (the final block), the network is detecting high-level concepts: object parts, spatial relationships, semantic features.

The output of layer4 is a tensor of shape [1, 512, 15, 20]. The activations are simply the values in this tensor. At spatial position (row 5, col 10), channel 217 might have a value of 3.7 — meaning "feature detector #217 fired strongly at this location."

You can think of each of the 512 channels as a different "question" the CNN is asking about the image. Channel 42 might respond to red-ish objects. Channel 198 might respond to edges of rectangular shapes. Channel 401 might respond to the gripper's silhouette. We don't choose what they detect — the network learns this during training.

[Figure: ResNet layer4 activations: 512 learned feature detectors over a 15×20 spatial grid]

The problem: we have 512 × 15 × 20 = 153,600 activation values. Which of these actually matter for the action the model predicts? That's where gradients come in.

Why do we hook into layer4 specifically, rather than layer1 or layer2?
Earlier layers capture low-level features (edges, textures) that are too generic to be interpretable — almost every part of the image has edges. layer4 captures high-level semantic features (objects, parts) that correspond to meaningful concepts. Its spatial resolution (15×20) is also coarse enough to produce clean heatmaps rather than noisy pixel-level splotches.
📉

Gradients

Which activations the action output is sensitive to

Activations tell you what the CNN sees. But the model might see a hundred things and only care about three of them for its action decision. Gradients tell you what the model cares about.

Here's the idea. The model computes: image → activations → ... (many layers of math) ... → action output. That entire chain is differentiable. We can ask PyTorch: "if I slightly increased activation value at position (row 5, col 10, channel 217), how much would the action output change?"

That's exactly what .backward() computes. Because .backward() starts from a single number, the predicted action vector has to be reduced to a scalar first (summing it is one simple choice). From there it traces the chain of math backwards to the activations and computes the partial derivative of that scalar with respect to each activation. A large gradient means the output is sensitive to that activation. A near-zero gradient means the model doesn't care about it, no matter how strongly it fired.

Critically, this does not modify the model. No weights are updated. There's no optimizer step. We're just asking a hypothetical: "how sensitive is the output to each internal value?" The model is identical before and after.

In normal inference, gradients are disabled (torch.no_grad()) for speed. Our visualizer intentionally leaves them enabled so the computation graph is built and .backward() can flow through it.
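Here is a small sketch of that idea on a toy model. The `backbone` and `head` modules are hypothetical stand-ins, not the real ACT modules, and summing the action vector is just one possible way to get a scalar. It shows that a backward pass yields a gradient for every activation while leaving the weights untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the CNN backbone and the rest of the policy.
backbone = nn.Conv2d(3, 4, kernel_size=3, padding=1)
head = nn.Linear(4, 2)
params = list(backbone.parameters()) + list(head.parameters())

image = torch.randn(1, 3, 6, 8)
weights_before = [p.detach().clone() for p in params]

# Gradients stay enabled (no torch.no_grad()), so the graph is recorded.
activations = backbone(image)                 # [1, 4, 6, 8]
activations.retain_grad()                     # keep the gradient of this intermediate tensor
actions = head(activations.mean(dim=(2, 3)))  # [1, 2]

# Reduce to a scalar, then backpropagate.
actions.sum().backward()

sensitivity = activations.grad                # same shape as activations
unchanged = all(torch.equal(a, b) for a, b in zip(weights_before, params))
```

Note that .backward() does fill the `.grad` buffers of the parameters. Without an optimizer step the weights stay the same, but you may want to clear those buffers between frames.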

[Figure: Forward pass builds the computation graph; backward pass traces sensitivity from output to activations]

Gradients answer: "if this activation were slightly different, how much would the action change?" Large gradient = high influence. Zero gradient = irrelevant to the decision.

If an activation has a large value but a near-zero gradient, what does that mean?
The CNN strongly detected something at that position (high activation), but the downstream transformer and action head didn't use that information for the action prediction (low gradient). The model sees it but doesn't care about it for this particular decision. For example, the CNN might strongly detect the table surface, but the action doesn't depend on the table — it depends on the cube.
🔥

GradCAM

Activations × Gradients = what drove the decision

Now we combine the two pieces. We have activations (shape [1, 512, 15, 20]) and gradients (same shape). GradCAM combines them in three steps:

1. Compute importance weights. For each of the 512 channels, average the gradient values over the entire 15×20 spatial grid. This gives one number per channel: "how important is this feature detector overall?"

2. Weighted sum. Multiply each channel's activation map (15×20) by its importance weight, then sum across all 512 channels. This collapses the 512 channels into a single 15×20 heatmap.

3. ReLU and normalize. Apply ReLU (zero out negatives — we only care about features with positive influence) and normalize to [0, 1].

The result is a 15×20 heatmap. Each cell tells you: "how much did the CNN features at this spatial location drive the action prediction?" We resize this to the original image dimensions and overlay it using a jet colormap.
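The three steps can also be written compactly, following the notation of the Grad-CAM paper. Here $A^k$ is the activation map of channel $k$, $y$ is whatever scalar the backward pass starts from (the text above doesn't pin down how the action output is reduced, so treat $y$ as an assumption), and $H = 15$, $W = 20$, $K = 512$:

```latex
\alpha_k = \frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} \frac{\partial y}{\partial A^{k}_{ij}},
\qquad
L_{ij} = \mathrm{ReLU}\!\left( \sum_{k=1}^{K} \alpha_k \, A^{k}_{ij} \right),
\qquad
\hat{L}_{ij} = \frac{L_{ij}}{\max_{i',j'} L_{i'j'}}
```

The first expression is step 1 (importance weights), the second is step 2 plus the ReLU, and the third is the normalization to [0, 1], applied only when the maximum is positive.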

[Figure: GradCAM: per-channel importance weights × activations → spatial heatmap]

A hot spot on the GradCAM heatmap means: "features at this image region were both strongly detected AND strongly influential on the action." Not just seen, but used.

GradCAM core computation

    act = hooks.activations[i]   # [1, 512, 15, 20]
    grad = hooks.gradients[i]    # [1, 512, 15, 20]

    # Step 1: importance weight per channel
    weights = grad.mean(dim=(2, 3), keepdim=True)   # [1, 512, 1, 1]

    # Step 2: weighted sum across channels
    cam = torch.relu((weights * act).sum(dim=1, keepdim=True))   # [1, 1, 15, 20]

    # Step 3: normalize
    cam = cam.squeeze().cpu().numpy()   # [15, 20]
    if cam.max() > 0:
        cam = cam / cam.max()   # [0, 1]

Why do we average the gradients over the spatial dimensions (step 1) rather than using them pixel-by-pixel?
Averaging produces one importance score per channel, which acts as a "global" signal for how important that entire feature detector is. Using per-pixel gradients would create a noisier, harder-to-interpret map. The spatial information already comes from the activations themselves — the gradient averaging just tells us which channels matter, and the activation values tell us where those channels fire.
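The last step, resizing the 15×20 heatmap to image size and overlaying it, can be sketched with plain NumPy. This is a simplified stand-in, not the visualizer's code: it uses nearest-neighbor upsampling and a red tint instead of a jet colormap, and the function names are made up for this example.

```python
import numpy as np

def upsample_nearest(cam, factor):
    """Repeat each heatmap cell `factor` times along both axes."""
    return np.repeat(np.repeat(cam, factor, axis=0), factor, axis=1)

def overlay_red(image, heat, alpha=0.5):
    """Blend a [0, 1] heatmap into an RGB float image as a red tint."""
    tint = np.zeros_like(image)
    tint[..., 0] = 1.0                      # pure red
    weight = (alpha * heat)[..., None]      # per-pixel blend weight, [H, W, 1]
    return (1.0 - weight) * image + weight * tint

# Toy heatmap with a single hot cell at row 5, col 10.
cam = np.zeros((15, 20), dtype=np.float32)
cam[5, 10] = 1.0

# Uniform dark-gray stand-in for a 480x640 camera frame, values in [0, 1].
image = np.full((480, 640, 3), 0.2, dtype=np.float32)

heat = upsample_nearest(cam, 32)            # [480, 640]
blended = overlay_red(image, heat)          # [480, 640, 3]
```

Nearest-neighbor upsampling keeps the 32×32 cell boundaries visible, which is a useful reminder of how coarse the underlying map is. A smoother interpolation looks nicer but suggests more spatial precision than the heatmap has.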
🪝

The Hook Trick

Intercepting internal tensors without modifying the model

There's a practical problem: the activations at layer4 are intermediate values. They exist briefly during the forward pass and are then consumed by the next layer. PyTorch doesn't save them by default. So how do we capture them?

Hooks. PyTorch lets you attach callback functions to any layer. A forward hook fires after a layer's forward pass and receives its output. A backward hook fires during .backward() and receives the gradients flowing through that layer.

Attaching hooks to capture activations and gradients

    act_list, grad_list = [], []   # one entry per camera

    def fwd(_mod, _inp, out):   # fires after layer4.forward()
        act_list.append(out.detach())   # save the feature map

    def bwd(_mod, _gi, go):   # fires during .backward()
        grad_list.insert(0, go[0].detach())   # save gradients (reverse order!)

    fh = target_layer.register_forward_hook(fwd)
    bh = target_layer.register_full_backward_hook(bwd)

There's a subtle gotcha with multiple cameras. The ACT model processes each camera image through the backbone in a loop:

Inside ACT's forward pass

    for img in batch["observation.images"]:
        cam_features = self.backbone(img)["feature_map"]   # hook fires here!

The hook fires once per camera. If you only store a single value (overwriting each time), you get the last camera's data for all cameras — a real bug we hit. The fix: append to a list.

Backward hooks fire in reverse order. If the forward pass processes [front_cam, wrist_cam], the backward pass processes gradients in reverse: [wrist_cam, front_cam]. So we use insert(0, ...) (prepend) instead of append to keep the lists aligned.

[Figure: Hooks fire per camera: forward in order, backward in reverse; lists must stay aligned]

Forward hooks fire in the order the backbone processes cameras. Backward hooks fire in reverse. append for forward, insert(0, ...) for backward keeps them paired correctly.
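You can check the alignment on a toy setup. The sketch below uses a single convolution as a hypothetical stand-in for layer4, runs it once per "camera", and weights the second camera's features twice as much in the loss so the two gradients are easy to tell apart. It is an illustration of the pattern, not the visualizer's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Conv2d(3, 4, kernel_size=3, padding=1)  # stand-in for layer4

act_list, grad_list = [], []

def fwd(_mod, _inp, out):
    act_list.append(out.detach())               # forward order: cam 0, then cam 1

def bwd(_mod, _grad_in, grad_out):
    grad_list.insert(0, grad_out[0].detach())   # prepend to undo the reversed order

fh = backbone.register_forward_hook(fwd)
bh = backbone.register_full_backward_hook(bwd)

cameras = [torch.randn(1, 3, 6, 8), torch.randn(1, 3, 6, 8)]
features = [backbone(img) for img in cameras]   # shared backbone, one call per camera

# d(loss)/d(features[0]) is all ones, d(loss)/d(features[1]) is all twos.
loss = features[0].sum() + 2.0 * features[1].sum()
loss.backward()

# Always remove hooks when done so they don't fire on later passes.
fh.remove()
bh.remove()

aligned = (
    torch.allclose(grad_list[0], torch.ones_like(act_list[0]))
    and torch.allclose(grad_list[1], 2.0 * torch.ones_like(act_list[1]))
)
```

If you swap insert(0, ...) for append in `bwd`, the two gradients end up paired with the wrong camera's activations, which is the misalignment described above.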
👁

Attention Maps

What the transformer decoder attends to

Attention visualization works completely differently from GradCAM. No backward pass, no gradients. Instead, we look at the attention weights the transformer already computes during its normal forward pass.

In the ACT decoder, each layer has a cross-attention sublayer. This is where the decoder queries (one per action timestep) attend to the encoder output. The attention weights are a matrix that says: "for each decoder query, how much did it attend to each encoder token?"
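To see what such a weight matrix looks like, here is a sketch that grabs it from a standalone nn.MultiheadAttention with a forward hook. The sizes are toy values and the module is a stand-in for a decoder layer's cross-attention, not the ACT model itself. Inside a real policy, the weights you get depend on how the decoder layer calls the module: if it passes need_weights=False, the second output is None, and by default PyTorch returns weights already averaged over heads.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2
num_queries, num_tokens, batch = 5, 12, 1

# Stand-in for one decoder layer's cross-attention module.
cross_attn = nn.MultiheadAttention(embed_dim, num_heads)
captured = []

def grab_weights(_mod, _args, output):
    # nn.MultiheadAttention returns a tuple: (attn_output, attn_weights).
    captured.append(output[1].detach())

handle = cross_attn.register_forward_hook(grab_weights)

queries = torch.randn(num_queries, batch, embed_dim)     # [L, B, E]
encoder_out = torch.randn(num_tokens, batch, embed_dim)  # [S, B, E]

with torch.no_grad():  # no backward pass needed for attention maps
    cross_attn(queries, encoder_out, encoder_out,
               need_weights=True, average_attn_weights=False)
handle.remove()

weights = captured[0]           # [B, heads, L, S]
row_sums = weights.sum(dim=-1)  # each query's weights sum to 1
```

Each row of the matrix is a softmax over encoder tokens, so it is a probability distribution describing where one query "reads" from.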

The encoder output is a sequence of tokens. The token layout looks like this:

Encoder output token layout: [latent, (robot_state), *cam0_tokens, *cam1_tokens, ...]

The image tokens from each camera are flattened feature maps. Camera 0 contributes 15×20 = 300 tokens, camera 1 contributes another 300. We know the spatial shape of each camera's feature map from the backbone, so we can reshape the attention weights back into 2D spatial maps.

The steps are:

1. Hook into the last decoder layer's multihead_attn to capture the cross-attention weight matrix.

2. Average over attention heads (each head attends to different things — averaging gives the overall picture).

3. Average over decoder query positions (all action timesteps combined — "what does the model attend to overall?").

4. Slice out each camera's tokens and reshape to (H, W) spatial maps.

5. Normalize globally across all cameras so they're comparable.
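Steps 2 to 5 can be sketched in NumPy as follows. The function name `attention_to_maps` and its arguments are made up for this example, the weights are random softmax rows rather than real model output, and the layout assumes two prefix tokens (latent and robot state) followed by each camera's tokens.

```python
import numpy as np

def attention_to_maps(weights, cam_shapes, num_prefix_tokens):
    """weights: [heads, queries, tokens] -> list of [H, W] maps, one per camera."""
    # Steps 2 and 3: average over heads, then over decoder queries -> [tokens].
    per_token = weights.mean(axis=0).mean(axis=0)

    # Step 4: slice out each camera's tokens and fold them back to 2D.
    maps, start = [], num_prefix_tokens
    for h, w in cam_shapes:
        maps.append(per_token[start:start + h * w].reshape(h, w))
        start += h * w

    # Step 5: one shared scale so cameras stay comparable.
    global_max = max(m.max() for m in maps)
    if global_max > 0:
        maps = [m / global_max for m in maps]
    return maps

rng = np.random.default_rng(0)
num_tokens = 2 + 2 * 300                        # latent + state + two 15x20 cameras
logits = rng.normal(size=(8, 100, num_tokens))  # toy sizes: 8 heads, 100 queries
weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

maps = attention_to_maps(weights, cam_shapes=[(15, 20), (15, 20)], num_prefix_tokens=2)
```

Normalizing each camera separately would make every map peak at 1.0 and hide the fact that one camera may receive far less attention than another, which is why the scale is shared.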

[Figure: Cross-attention weights → per-camera spatial maps via token slicing and reshaping]

Attention maps show what the transformer queries are looking at. Unlike GradCAM, this requires no backward pass: attention weights are a natural byproduct of the forward pass. We just need to grab them.

Why do we look at the last decoder layer's cross-attention, rather than the first layer or the encoder's self-attention?
The last decoder layer produces the final representation that directly drives the action prediction, so its attention pattern is the most decision-relevant. Earlier layers might attend broadly while still gathering context. The encoder's self-attention shows how image tokens relate to each other, not how the action prediction uses them.
🦾

Proprioception

How much does the model rely on joint state vs. vision?

Among the encoder tokens, one is special: the robot state token (proprioception). This is the linear projection of the joint positions/velocities. When the decoder cross-attends to the encoder output, some attention goes to image tokens and some goes to this state token.

We extract the attention weight on the proprioception token separately. This gives a single number (0 to 1 after normalization) that says: "what fraction of the decoder's attention went to joint state vs. visual information?"
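One way to get such a number is sketched below. The text doesn't spell out the exact normalization, so this assumes the value is the state token's share of the total attention mass, with the state token at index 1 of the layout shown earlier. `proprio_fraction` is a hypothetical helper name.

```python
import numpy as np

def proprio_fraction(per_token_attention, state_index=1):
    """Share of total attention mass that lands on the robot state token."""
    total = per_token_attention.sum()
    if total <= 0:
        return 0.0
    return float(per_token_attention[state_index] / total)

# Toy head- and query-averaged attention: [latent, state, 600 image tokens].
attn = np.array([0.05, 0.35] + [0.001] * 600)

fraction = proprio_fraction(attn)               # about 0.35
border_intensity = int(round(255 * fraction))   # e.g. brightness of the border color
```

Keep in mind that a single state token is competing with hundreds of image tokens, so even a modest-looking fraction can mean the state token gets far more attention than any individual image cell.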

This is visualized as a colored border around the overlay image. Bright magenta border = the model is leaning heavily on proprioception for this frame. Dim or no border = visual features dominate the decision.

[Figure: Proprioception attention value rendered as border intensity: bright = relying on joint state, dim = relying on vision]

This can reveal interesting dynamics over an episode. You might see the model rely on vision during the approach phase (need to locate the cube), then switch to proprioception during grasping (need to feel whether the gripper has closed), then back to vision for placement.

If the proprioception border is consistently bright across all frames of an episode, what might that indicate?
The model is primarily relying on joint state rather than visual input to make decisions. This could mean: (1) the task can be solved largely through proprioception (e.g., a fixed pick location), (2) the visual features aren't discriminative enough for the model to rely on, or (3) the cameras aren't providing useful viewpoints. It's a signal to investigate whether the model would generalize to different object positions.

GradCAM vs. Attention

Two lenses on the same model — when to use which
[Figure: GradCAM vs. Attention: different questions, different answers, complementary insights]

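GradCAM tells you which backbone features the action output is sensitive to. Its gradients flow back through the whole model, transformer included, so a feature the rest of the network ignores gets a near-zero weight. But it is a first-order sensitivity estimate, not proof of causal influence: it says "a small change to these features would change the action," which can differ from what happens under a large change. And because each channel's weight is averaged over the whole grid, a location where an important channel fires strongly can light up even if the gradient at that exact cell is small.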

Attention maps tell you what the transformer decoder queries for. It's forward-pass-only, so it's fast and doesn't require gradient computation. But attention weights are not necessarily causal — a token can have high attention weight without actually influencing the output (if the value vectors carry redundant information).

When they agree: both highlight the same image region → strong evidence that region drives the action.

When they disagree: interesting. GradCAM hot but attention cold suggests the CNN extracts a feature there that influences the action through some path the averaged last-layer attention map doesn't capture (e.g., encoder self-attention mixing that information into other tokens, such as the robot state token). Attention hot but GradCAM cold suggests the decoder looks there, but the gradient signal is weak: it may be diluted across many channels, or its net influence may be negative and removed by the ReLU.

Run both modes on the same episode and compare. Agreement means confidence. Disagreement means something interesting is happening in the model's internal processing.
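If you want a rough number for "agreement" instead of eyeballing two overlays, one simple option is to compare the hottest cells of each map. The sketch below is a hypothetical helper, not something the visualizer provides, and the choice of k is arbitrary.

```python
import numpy as np

def topk_overlap(map_a, map_b, k=30):
    """Intersection-over-union of the k hottest cells in two same-shaped heatmaps."""
    top_a = set(np.argsort(map_a.ravel())[-k:].tolist())
    top_b = set(np.argsort(map_b.ravel())[-k:].tolist())
    return len(top_a & top_b) / len(top_a | top_b)

# Toy maps: GradCAM and attention both focus on the same 3x3 region...
gradcam = np.zeros((15, 20))
gradcam[4:7, 9:12] = 1.0
attention = np.zeros((15, 20))
attention[4:7, 9:12] = 0.8

# ...while a third map focuses on a different corner of the image.
elsewhere = np.zeros((15, 20))
elsewhere[12:15, 0:3] = 1.0

agree = topk_overlap(gradcam, attention, k=9)
disagree = topk_overlap(gradcam, elsewhere, k=9)
```

A score like this only compares where the maps are hot, not how hot, so it is a screening tool for finding frames worth a closer look rather than a measure of what the model is doing.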
🏁

Recap

The full picture in one place

Activations = what the CNN detects. 512 feature detectors fire at different spatial locations. This is the raw output of the ResNet backbone.

Gradients = what the action output is sensitive to. Computed by backpropagation. Large gradient means that activation influenced the action.

GradCAM = activations × gradients, collapsed to a spatial heatmap. Shows which image regions drove the action through the CNN pathway. Requires a backward pass but doesn't modify the model.

Attention maps = cross-attention weights from the decoder's last layer, reshaped back to spatial maps. Shows what the transformer queries for when making action predictions. Forward-pass only.

Proprioception attention = how much the decoder attends to the robot state token vs. image tokens. Visualized as border intensity.

Hooks = PyTorch callbacks that intercept intermediate tensors. Forward hooks capture activations (one per camera, in order). Backward hooks capture gradients (one per camera, in reverse order — use insert(0, ...)).

Running the visualizer
    # GradCAM mode (what CNN features drive actions)
    uv run python operation_utils/gradcam_policy_visualizer.py \
        --model Servo7/your-model --dataset Servo7/your-dataset \
        --episode 0 --mode gradcam

    # Attention mode (what the transformer decoder attends to)
    uv run python operation_utils/gradcam_policy_visualizer.py \
        --model Servo7/your-model --dataset Servo7/your-dataset \
        --episode 0 --mode attention

Further reading:

  • Grad-CAM: Visual Explanations from Deep Networks (Selvaraju et al., 2017)
  • Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (Zhao et al., 2023) — ACT
  • villekuosmanen/physical-AI-interpretability — attention visualization for ACT