"""PyTorch implementation of SmoothGrad."""
import torch
import torch.nn as nn


class SmoothGrad:
    """SmoothGrad attribution method.

    Implements SmoothGrad as described in the original paper:
    "SmoothGrad: removing noise by adding noise"
    https://arxiv.org/abs/1706.03825
    """

    def __init__(self, model, num_samples=16, noise_scale=1.0):
        """Initialize SmoothGrad.

        Args:
            model: PyTorch model.
            num_samples: Number of noisy samples to use (matches TF default 16).
            noise_scale: Standard deviation of noise to add (matches TF behavior, default 1.0).
        """
        self.model = model
        self.num_samples = num_samples
        self.noise_scale = noise_scale

    def attribute(self, inputs, target=None, num_samples=None, noise_scale=None):
        """Calculate SmoothGrad attribution.

        Args:
            inputs: Input tensor.
            target: Target class index (None for argmax).
            num_samples: Override the number of samples (optional).
            noise_scale: Override the noise scale (optional).

        Returns:
            Attribution tensor of the same shape as inputs.
        """
        # Get parameters (use instance defaults if not provided)
        num_samples = num_samples if num_samples is not None else self.num_samples
        noise_scale = noise_scale if noise_scale is not None else self.noise_scale

        # Ensure input is a tensor
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.tensor(inputs, dtype=torch.float32)

        # Clone inputs to avoid modifying the original
        inputs = inputs.clone().detach()

        # Store original model mode and switch to eval for deterministic passes
        original_mode = self.model.training
        self.model.eval()

        # Accumulate gradients over the noisy samples
        accumulated_gradients = torch.zeros_like(inputs)
        for _ in range(num_samples):
            # Generate noisy input; noise_scale is a fixed standard deviation
            # (matches TensorFlow's np.random.normal behavior)
            noise = torch.normal(0.0, noise_scale, size=inputs.shape, device=inputs.device)
            noisy_input = inputs + noise
            noisy_input.requires_grad_(True)

            # Forward pass
            self.model.zero_grad()
            output = self.model(noisy_input)

            # Determine target classes
            if target is None:
                target_indices = output.argmax(dim=1)
            elif isinstance(target, int):
                target_indices = torch.full((inputs.shape[0],), target, dtype=torch.long, device=inputs.device)
            elif isinstance(target, torch.Tensor):
                if target.numel() == 1:  # Single class for all examples
                    target_indices = torch.full((inputs.shape[0],), int(target.item()), dtype=torch.long, device=inputs.device)
                else:  # Different target for each example
                    target_indices = target.to(dtype=torch.long, device=inputs.device)
            else:
                raise ValueError(f"Unsupported target type: {type(target)}")

            # One-hot encoding for the target classes
            one_hot = torch.zeros_like(output)
            one_hot.scatter_(1, target_indices.view(-1, 1), 1.0)

            # Backward pass: gradients of the selected class scores w.r.t. the input
            output.backward(gradient=one_hot)

            # Accumulate input gradients
            if noisy_input.grad is not None:
                accumulated_gradients += noisy_input.grad

        # Restore original model mode
        self.model.train(original_mode)

        # Average gradients over the samples
        smoothgrad_attribution = accumulated_gradients / num_samples

        # Zero out tiny values for numerical stability
        smoothgrad_attribution[torch.abs(smoothgrad_attribution) < 1e-10] = 0.0
        return smoothgrad_attribution


class SmoothGradXSign(SmoothGrad):
    """SmoothGrad × Sign attribution method.

    Implements SmoothGrad multiplied by the sign of (input - threshold),
    which can emphasize both positive and negative contributions.
    """

    def __init__(self, model, num_samples=16, noise_scale=1.0, mu=0.0):
        """Initialize SmoothGradXSign.

        Args:
            model: PyTorch model.
            num_samples: Number of noisy samples to use (matches TF default 16).
            noise_scale: Standard deviation of noise to add (matches TF behavior, default 1.0).
            mu: Threshold value for the sign function.
        """
        super().__init__(model, num_samples, noise_scale)
        self.mu = mu

    def attribute(self, inputs, target=None, num_samples=None, noise_scale=None, mu=None):
        """Calculate SmoothGrad × Sign attribution.

        Args:
            inputs: Input tensor.
            target: Target class index (None for argmax).
            num_samples: Override the number of samples (optional).
            noise_scale: Override the noise scale (optional).
            mu: Override the threshold value (optional).

        Returns:
            Attribution tensor of the same shape as inputs.
        """
        # Get smooth gradients
        smooth_gradients = super().attribute(inputs, target, num_samples, noise_scale)

        # Ensure input is a tensor
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.tensor(inputs, dtype=torch.float32)

        # Get threshold value (use instance default if not provided)
        mu_value = mu if mu is not None else self.mu

        # Multiply element-wise by the sign of (input - threshold)
        input_sign = torch.sign(inputs.clone().detach() - mu_value)
        return smooth_gradients * input_sign


def smoothgrad(model, inputs, target=None, num_samples=16, noise_scale=1.0):
    """Calculate SmoothGrad attribution (functional API).

    Args:
        model: PyTorch model.
        inputs: Input tensor.
        target: Target class index (None for argmax).
        num_samples: Number of noisy samples to use.
        noise_scale: Standard deviation of noise to add.

    Returns:
        Attribution tensor of the same shape as inputs.
    """
    # Create a SmoothGrad instance and calculate the attribution
    return SmoothGrad(model, num_samples, noise_scale).attribute(inputs, target)


def smoothgrad_x_sign(model, inputs, target=None, num_samples=16, noise_scale=1.0, mu=0.0):
    """Calculate SmoothGrad × Sign attribution (functional API).

    Args:
        model: PyTorch model.
        inputs: Input tensor.
        target: Target class index (None for argmax).
        num_samples: Number of noisy samples to use.
        noise_scale: Standard deviation of noise to add.
        mu: Threshold value for the sign function.

    Returns:
        Attribution tensor of the same shape as inputs.
    """
    # Create a SmoothGradXSign instance and calculate the attribution
    return SmoothGradXSign(model, num_samples, noise_scale, mu).attribute(inputs, target)
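

if __name__ == "__main__":
    # Minimal usage sketch, not part of the API above: a hypothetical
    # two-layer classifier on random data, used only to exercise both the
    # class-based and functional interfaces.
    torch.manual_seed(0)
    demo_model = nn.Sequential(
        nn.Linear(8, 16),
        nn.ReLU(),
        nn.Linear(16, 3),
    )
    demo_inputs = torch.randn(4, 8)

    # Class-based API: explain each example's argmax class
    explainer = SmoothGrad(demo_model, num_samples=16, noise_scale=0.1)
    attributions = explainer.attribute(demo_inputs)
    print(attributions.shape)  # torch.Size([4, 8]), same shape as the inputs

    # Functional API with an explicit target class, plus the sign variant
    attributions = smoothgrad(demo_model, demo_inputs, target=1)
    signed = smoothgrad_x_sign(demo_model, demo_inputs, target=1, mu=0.0)
    print(signed.shape)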