"""Unified PyTorch implementation of Grad-CAM combining the best features from both implementations."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Union, Optional, Tuple, List
# NOTE: stray "[docs]" Sphinx source-link marker removed here (page-extraction residue, not code).
class GradCAM:
    """Unified Grad-CAM implementation for PyTorch models.

    Combines the automatic layer detection from gradcam.py with the
    TensorFlow-compatible behavior from grad_cam.py.

    Grad-CAM uses the gradients of a target concept flowing into the final
    convolutional layer to produce a coarse localization map highlighting
    important regions in the image for prediction.

    Attributes:
        model: The module that is actually executed. This is a deep copy of
            the caller's model (wrapped) when the model contains inplace
            ReLUs, otherwise the caller's model itself.
        original_model: The model object passed by the caller.
        target_layer: The layer of ``model`` whose activations/gradients are
            captured.
        activations / gradients: Tensors captured by the hooks during the
            most recent ``forward`` call (``None`` before the first call).
        hooks: Handles of the currently registered hooks.
    """

    def __init__(self, model, target_layer=None):
        """Initialize GradCAM.

        Args:
            model: PyTorch model.
            target_layer: Target layer for Grad-CAM. If None, will try to
                automatically find the last convolutional layer.

        Raises:
            ValueError: If no target layer was given and none could be found.
        """
        self.original_model = model  # Keep reference to original
        # Wrap model to avoid inplace operations if needed
        self.model = self._wrap_model_no_inplace(model)
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hooks = []

        if self.target_layer is None:
            # Not provided: look for the last convolutional layer.
            self.target_layer = self._find_target_layer(self.model)
        elif self.model is not model:
            # The model was deep-copied, so a layer taken from the caller's
            # model is never executed. Hook the matching layer of the copy.
            self.target_layer = self._map_layer_to_copy(
                model, self.model, self.target_layer)

        if self.target_layer is None:
            raise ValueError("Could not automatically identify a target convolutional layer. "
                             "Please specify one explicitly.")

        self._register_hooks()

    @staticmethod
    def _map_layer_to_copy(original, wrapper, layer):
        """Return the module inside ``wrapper.model`` that corresponds to ``layer``.

        The correspondence is by qualified module name. If ``layer`` is not a
        submodule of ``original`` it is returned unchanged.
        """
        copied_modules = dict(wrapper.model.named_modules())
        for name, module in original.named_modules():
            if module is layer:
                return copied_modules.get(name, layer)
        return layer

    def _wrap_model_no_inplace(self, model):
        """Wrap model to replace inplace operations that can cause issues with hooks.

        Returns ``model`` itself when none of its submodules is an inplace
        ``nn.ReLU``; otherwise returns a wrapper around a deep copy in which
        those ReLUs are replaced by non-inplace ones.
        """
        import copy

        class NoInplaceWrapper(nn.Module):
            def __init__(self, model):
                super().__init__()
                # Use deepcopy to avoid modifying the original model
                self.model = copy.deepcopy(model)
                self._replace_inplace_relu(self.model)

            def _replace_inplace_relu(self, module):
                """Recursively replace inplace ReLUs."""
                for name, child in module.named_children():
                    if isinstance(child, nn.ReLU) and child.inplace:
                        setattr(module, name, nn.ReLU(inplace=False))
                    else:
                        self._replace_inplace_relu(child)

            def forward(self, x):
                return self.model(x)

            def __getattr__(self, name):
                # Delegate attribute access to the wrapped model
                try:
                    return super().__getattr__(name)
                except AttributeError:
                    if name == 'model':
                        # Wrapped model not set yet: delegating would recurse.
                        raise
                    return getattr(self.model, name)

        # Only wrap (and pay for the deep copy) if a submodule is an inplace ReLU.
        has_inplace = any(
            isinstance(module, nn.ReLU) and module.inplace
            for module in model.modules()
            if module is not model
        )
        if has_inplace:
            return NoInplaceWrapper(model)
        return model

    def _find_target_layer(self, model):
        """Find the last convolutional layer in the model.

        This method searches for Conv2d (images) and Conv1d (time series) layers.

        Lookup order:
            1. ``model.layer4[-1].conv2`` (ResNet-like models),
            2. the last conv layer anywhere inside ``model.features``
               (VGG-like models, including nested containers),
            3. the last conv layer anywhere in the model.

        Returns:
            The layer found, or None if the model has no Conv1d/Conv2d layer.
        """
        conv_types = (nn.Conv2d, nn.Conv1d)

        def last_conv_in(module):
            # modules() walks the tree in registration order, so the final
            # match is the last conv layer defined under ``module``.
            found = None
            for sub in module.modules():
                if isinstance(sub, conv_types):
                    found = sub
            return found

        # ResNet-like models
        layer4 = getattr(model, 'layer4', None)
        if layer4 is not None:
            try:
                # NOTE(review): the final block may define further convs after
                # conv2; conv2 is kept for backward compatibility.
                return layer4[-1].conv2
            except (AttributeError, IndexError, TypeError):
                pass  # Not the expected layout; fall back to the searches below.

        # VGG-like models
        features = getattr(model, 'features', None)
        if isinstance(features, nn.Module):
            conv = last_conv_in(features)
            if conv is not None:
                return conv

        # Generic search for the last conv layer
        return last_conv_in(model)

    def _register_hooks(self):
        """Register forward and backward hooks on the target layer."""
        # Clear any existing hooks
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

        def forward_hook(module, inputs, output):
            # Clone to avoid inplace modification issues
            self.activations = output.clone().detach()

        def backward_hook(module, grad_input, grad_output):
            # Clone to avoid inplace modification issues
            self.gradients = grad_output[0].clone().detach()

        forward_handle = self.target_layer.register_forward_hook(forward_hook)
        # Use register_full_backward_hook for newer PyTorch, fallback for older versions
        try:
            backward_handle = self.target_layer.register_full_backward_hook(backward_hook)
        except AttributeError:
            backward_handle = self.target_layer.register_backward_hook(backward_hook)
        self.hooks.extend([forward_handle, backward_handle])

    @staticmethod
    def _to_class_index(target_class, output):
        """Convert ``target_class`` to an int64 index tensor on ``output``'s device.

        ``None`` selects the argmax of ``output``; ints, sequences, numpy
        arrays and tensors are accepted.
        """
        if target_class is None:
            if output.dim() > 1:
                return output.argmax(dim=1)
            return output.argmax()
        if isinstance(target_class, (int, np.integer)):
            target_class = [int(target_class)]
        # scatter_ requires int64 indices (numpy arrays may be int32).
        return torch.as_tensor(target_class, device=output.device).long()

    def forward(self, x, target_class=None):
        """Generate Grad-CAM attribution map using the TensorFlow-compatible approach.

        Channel weights are the mean of the captured gradients over the batch
        and spatial/time axes; the map is built from the first sample's
        activations, passed through ReLU and divided by its maximum.

        Args:
            x: Input tensor (with a leading batch dimension).
            target_class: Target class index (None for argmax). May be an
                int, a sequence, a numpy array or a tensor.

        Returns:
            Grad-CAM attribution map as a tensor at the target layer's
            resolution (it is not resized here).

        Raises:
            ValueError: If the hooks captured no activations or gradients.
        """
        # Hooks are removed at the end of every call; re-register them so the
        # same instance can be used more than once.
        if not self.hooks:
            self._register_hooks()

        original_mode = self.model.training
        self.model.eval()
        try:
            # Clone input to avoid modifying the original
            x = x.clone().detach().requires_grad_(True)

            # Reset stored activations and gradients
            self.activations = None
            self.gradients = None

            # Gradients are required even if the caller is inside no_grad().
            with torch.enable_grad():
                self.model.zero_grad()
                output = self.model(x)

                target_class = self._to_class_index(target_class, output)

                # One-hot gradient seed selecting the target logit(s)
                one_hot = torch.zeros_like(output)
                if output.dim() == 2:  # Batch output
                    one_hot.scatter_(1, target_class.reshape(-1, 1), 1.0)
                else:  # Single output
                    if target_class.dim() > 0:
                        target_class = target_class[0]
                    one_hot[target_class] = 1.0

                output.backward(gradient=one_hot, retain_graph=True)

            if self.activations is None or self.gradients is None:
                raise ValueError("Could not capture activations or gradients. "
                                 "Check that the target layer is correct.")

            # Calculate weights - match TensorFlow's reduce_mean behavior
            if self.gradients.dim() == 4:  # For images (B, C, H, W)
                weights = torch.mean(self.gradients, dim=(0, 2, 3), keepdim=False)
            else:  # For time series (B, C, T)
                weights = torch.mean(self.gradients, dim=(0, 2), keepdim=False)

            # Extract first sample's activations (match TensorFlow behavior)
            activations = self.activations[0]

            # Weight activations by importance
            if activations.dim() == 3:  # (C, H, W)
                weighted_output = activations * weights[:, None, None]
            else:  # (C, T)
                weighted_output = activations * weights[:, None]

            # Sum across feature map channels, keep positive evidence only
            cam = F.relu(torch.sum(weighted_output, dim=0, keepdim=False))

            # TensorFlow-style normalization; epsilon guards an all-zero map
            epsilon = 1e-7
            cam = cam / (torch.max(cam) + epsilon)

            # No resizing here: callers resize, which avoids double resizing
            # and keeps the raw CAM available.
            return cam
        finally:
            # Always clean up hooks and restore the mode, even on failure.
            for hook in self.hooks:
                hook.remove()
            self.hooks = []
            self.model.train(original_mode)

    def attribute(self, inputs, target=None, resize_to_input=True):
        """Generate Grad-CAM heatmap (compatible with gradcam.py interface).

        Args:
            inputs: Input tensor (or array-like convertible to float32).
            target: Target class index (None for argmax).
            resize_to_input: Whether to resize heatmap to input size.

        Returns:
            For 4-D inputs, a ``(B, 1, H, W)`` tensor in which the single
            computed map is repeated for every batch entry; for other inputs
            the map as returned by ``forward``. The map has the input's
            spatial/time size if ``resize_to_input`` is True.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.tensor(inputs, dtype=torch.float32)

        cam = self.forward(inputs, target_class=target)

        if inputs.dim() == 4 and cam.dim() == 2:  # Batch input, single CAM
            cam = cam.unsqueeze(0).unsqueeze(0)  # Add batch and channel dims
            if resize_to_input and cam.shape[-2:] != inputs.shape[-2:]:
                cam = F.interpolate(cam, size=tuple(inputs.shape[-2:]),
                                    mode='bilinear', align_corners=False)
            if inputs.shape[0] > 1:
                # Repeat for batch size
                cam = cam.repeat(inputs.shape[0], 1, 1, 1)
        elif (resize_to_input and inputs.dim() == 3 and cam.dim() == 1
              and cam.shape[-1] != inputs.shape[-1]):
            # Time series: interpolate along the time axis, keep the 1-D shape.
            cam = F.interpolate(cam[None, None], size=int(inputs.shape[-1]),
                                mode='linear', align_corners=False)[0, 0]
        return cam
# NOTE: stray "[docs]" Sphinx source-link marker removed here (page-extraction residue, not code).
def calculate_grad_cam_relevancemap(model, input_tensor, target_layer=None, target_class=None, layer_name=None, **kwargs):
    """Calculate Grad-CAM relevance map for images.

    This function provides a convenient interface compatible with grad_cam.py.
    A 1-D map (produced when the target layer yields ``(B, C, T)``
    activations) is treated as a time-series map and resized along the last
    input axis; 2-D maps are resized to the input's last two axes.

    Args:
        model: PyTorch model
        input_tensor: Input tensor (array-likes are converted to float32)
        target_layer: Target layer for Grad-CAM (None to auto-detect)
        target_class: Target class index (None for argmax)
        layer_name: Alternative name for target_layer (for compatibility),
            e.g. ``"features.28"``; only used when target_layer is None
        **kwargs: Additional parameters (ignored)

    Returns:
        Grad-CAM relevance map as numpy array: ``(1, H, W)`` for 4-D input,
        ``(H, W)`` for 3-D image input, ``(T,)`` for time-series maps.
    """
    if not isinstance(input_tensor, torch.Tensor):
        input_tensor = torch.as_tensor(input_tensor, dtype=torch.float32)

    # Use layer_name if target_layer not provided (for compatibility)
    if target_layer is None and layer_name is not None:
        # Walk the dotted path; purely numeric parts index into containers.
        target_layer = model
        for part in layer_name.split('.'):
            if part.isdigit():
                target_layer = target_layer[int(part)]
            else:
                target_layer = getattr(target_layer, part)

    grad_cam = GradCAM(model, target_layer)
    with torch.enable_grad():
        cam = grad_cam.forward(input_tensor, target_class)

    if isinstance(cam, torch.Tensor):
        cam = cam.detach().cpu().numpy()

    # --- Time-series map: the input is (B, C, T), not an image (C, H, W) ---
    if cam.ndim == 1:
        source_len = cam.shape[0]
        target_len = int(input_tensor.shape[-1])
        if source_len and source_len != target_len:
            # Linear interpolation at sample-centred positions; np.interp
            # clamps positions that fall outside the source range.
            positions = (np.arange(target_len) + 0.5) * (source_len / target_len) - 0.5
            cam = np.interp(positions, np.arange(source_len), cam).astype(cam.dtype)
        return cam

    # --- Image map: resize to the input's (H, W) for consistency with TensorFlow ---
    target_size = tuple(int(size) for size in input_tensor.shape[-2:])
    if tuple(cam.shape[-2:]) != target_size:
        try:
            import cv2
        except ImportError:
            cv2 = None  # Fall back to scipy below

        if cv2 is not None:
            # cv2 expects (W, H) for resize
            cv2_size = (target_size[1], target_size[0])
            if cam.ndim == 2:
                cam = cv2.resize(cam, cv2_size)
            elif cam.ndim == 3:  # Batch of CAMs
                cam = np.stack([cv2.resize(single, cv2_size) for single in cam], axis=0)
        else:
            from scipy.ndimage import zoom
            if cam.ndim == 2:
                zoom_factors = (target_size[0] / cam.shape[0], target_size[1] / cam.shape[1])
                cam = zoom(cam, zoom_factors, order=1)
            elif cam.ndim == 3:
                zoom_factors = (1, target_size[0] / cam.shape[1], target_size[1] / cam.shape[2])
                cam = zoom(cam, zoom_factors, order=1)

    if input_tensor.dim() == 4:
        # Batch input: return with a batch dimension
        if cam.ndim == 2:
            cam = np.expand_dims(cam, axis=0)
    elif cam.ndim == 3 and cam.shape[0] == 1:
        # Single input: drop a leftover singleton batch dimension
        cam = cam[0]
    return cam
# NOTE: stray "[docs]" Sphinx source-link marker removed here (page-extraction residue, not code).
def calculate_grad_cam_relevancemap_timeseries(model, input_tensor, target_layer=None, target_class=None):
    """Compute a Grad-CAM relevance map for time-series input.

    Compatibility counterpart of grad_cam.py's time-series function. When no
    layer is given, the last ``nn.Conv1d`` in ``model.modules()`` order is
    used; the work is then delegated to ``calculate_grad_cam_relevancemap``.

    Args:
        model: PyTorch model
        input_tensor: Input tensor (B, C, T)
        target_layer: Layer to explain (None selects the last Conv1d)
        target_class: Class index to explain (None for argmax)

    Returns:
        Grad-CAM relevance map as numpy array

    Raises:
        ValueError: If no layer is given and the model has no Conv1d layer.
    """
    layer = target_layer
    if layer is None:
        conv1d_layers = [m for m in model.modules() if isinstance(m, nn.Conv1d)]
        if not conv1d_layers:
            raise ValueError("Could not find Conv1d layer for time series Grad-CAM")
        layer = conv1d_layers[-1]

    # Hand over to the unified implementation
    return calculate_grad_cam_relevancemap(model, input_tensor, layer, target_class)
# Aliases for backward compatibility
def find_target_layer(model):
    """Return the layer Grad-CAM would auto-select for ``model``.

    The lookup is done without constructing a full ``GradCAM`` object, so the
    model is neither deep-copied nor left with hooks registered on it.

    Args:
        model: PyTorch model

    Returns:
        The auto-detected target layer of ``model``.

    Raises:
        ValueError: If no convolutional layer could be identified.
    """
    # _find_target_layer reads no instance state, so a bare instance suffices.
    layer = GradCAM.__new__(GradCAM)._find_target_layer(model)
    if layer is None:
        raise ValueError("Could not automatically identify a target convolutional layer. "
                         "Please specify one explicitly.")
    return layer