# Source code for signxai.torch_signxai.methods_impl.guided

"""PyTorch implementation of Guided Backpropagation and DeconvNet methods."""
import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, List, Union, Optional, Callable, Dict, Any


[docs] class GuidedBackpropReLU(torch.autograd.Function): """Guided Backpropagation ReLU activation. This modified ReLU only passes positive gradients during backpropagation. It combines the backpropagation rules of DeconvNet and vanilla backpropagation. The TensorFlow implementation is: @tf.custom_gradient def guidedRelu(x): def grad(dy): return tf.cast(dy > 0, tf.float32) * tf.cast(x > 0, tf.float32) * dy return tf.nn.relu(x), grad """
[docs] @staticmethod def forward(ctx, input_tensor): ctx.save_for_backward(input_tensor) return torch.nn.functional.relu(input_tensor)
[docs] @staticmethod def backward(ctx, grad_output): input_tensor, = ctx.saved_tensors # Only pass positive gradients and only for positive inputs # This exactly matches the TensorFlow implementation: # tf.cast(dy > 0, tf.float32) * tf.cast(x > 0, tf.float32) * dy positive_grad_mask = (grad_output > 0).float() positive_input_mask = (input_tensor > 0).float() grad_input = positive_grad_mask * positive_input_mask * grad_output return grad_input
class GuidedBackpropReLUModule(nn.Module):
    """nn.Module wrapper around :class:`GuidedBackpropReLU`.

    Lets the custom autograd function stand in for ``nn.ReLU`` layers when
    rebuilding a model for guided backpropagation.
    """

    def forward(self, x):
        # Delegate straight to the custom autograd function.
        activated = GuidedBackpropReLU.apply(x)
        return activated
def replace_relu_with_guided_relu(model):
    """Recursively swap every ``nn.ReLU`` in *model* for a guided ReLU.

    The model is modified in place.

    Args:
        model: PyTorch model

    Returns:
        The same model, with all ReLU activations replaced by
        :class:`GuidedBackpropReLUModule`.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, nn.ReLU):
            # Not a ReLU leaf — descend into the submodule/container.
            replace_relu_with_guided_relu(child)
            continue
        setattr(model, child_name, GuidedBackpropReLUModule())
    return model
def build_guided_model(model):
    """Build a guided backpropagation model by replacing ReLU activations.

    A copy of *model* is used whenever one can be made, so the caller's
    original model is left untouched.

    Args:
        model: PyTorch model

    Returns:
        Model in eval mode whose ReLU activations have been replaced with
        guided ReLUs.
    """
    import copy

    try:
        # Preferred path: rebuild via the default constructor and copy weights.
        guided_model = type(model)()
        guided_model.load_state_dict(model.state_dict())
    except Exception:
        try:
            # Constructor needs arguments (or state dicts mismatch) — fall
            # back to a deep copy so the original still isn't mutated.
            guided_model = copy.deepcopy(model)
        except Exception:
            # Last resort: operate on the original model (previous behavior,
            # not ideal because its ReLUs get replaced in place).
            guided_model = model

    guided_model.eval()
    # Replace ReLU with Guided ReLU
    replace_relu_with_guided_relu(guided_model)
    return guided_model
def guided_backprop(model, input_tensor, target_class=None):
    """Generate a guided backpropagation attribution map.

    Args:
        model: PyTorch model (typically built with :func:`build_guided_model`)
        input_tensor: Input tensor of shape ``(batch, ...)``
        target_class: Target class index — ``None`` for per-example argmax,
            an ``int`` applied to every example, or a tensor of indices

    Returns:
        Gradient attribution tensor with the same shape as ``input_tensor``.
    """
    # Work on a detached copy that tracks gradients.
    input_tensor = input_tensor.clone().detach().requires_grad_(True)

    model.zero_grad()
    output = model(input_tensor)
    batch_size = output.shape[0]

    # Normalize target_class to a 1-D tensor of class indices.
    if target_class is None:
        target_class = output.argmax(dim=1)
    elif isinstance(target_class, int):
        target_class = torch.tensor([target_class], device=input_tensor.device)
    elif isinstance(target_class, torch.Tensor) and target_class.ndim == 0:
        # Scalar tensor -> 1-element vector.
        target_class = target_class.unsqueeze(0)

    targets = target_class.view(-1).to(output.device)
    if targets.numel() == 1 and batch_size > 1:
        # Single target shared by the whole batch: broadcast it so every
        # example gets a one-hot entry (previously only row 0 was set).
        targets = targets.expand(batch_size)

    # One-hot seed for the backward pass: one 1.0 per example.
    one_hot = torch.zeros_like(output)
    one_hot[torch.arange(batch_size, device=output.device), targets] = 1.0

    output.backward(gradient=one_hot)
    gradients = input_tensor.grad.clone()

    # Apply small value thresholding for numerical stability
    # This helps ensure outputs match between TensorFlow and PyTorch
    gradients[torch.abs(gradients) < 1e-10] = 0.0
    return gradients
class GuidedBackprop:
    """Class-based implementation of Guided Backpropagation."""

    def __init__(self, model):
        """Initialize Guided Backpropagation with the model.

        Args:
            model: PyTorch model
        """
        self.model = model
        # Build the modified-ReLU model once, up front.
        self.guided_model = build_guided_model(model)
        self._hooks = []  # For compatibility with tests

    def attribute(self, inputs, target=None):
        """Calculate attribution using Guided Backpropagation.

        Args:
            inputs: Input tensor
            target: Target class index (None for argmax)

        Returns:
            Attribution tensor of the same shape as inputs
        """
        attribution = guided_backprop(self.guided_model, inputs, target_class=target)
        return attribution
class DeconvNet:
    """Class-based implementation of DeconvNet."""

    def __init__(self, model):
        """Initialize DeconvNet with the model.

        Args:
            model: PyTorch model
        """
        from .deconvnet import build_deconvnet_model, deconvnet

        self.model = model
        # Only wrap objects that look like real modules; anything without a
        # state_dict is used as-is.
        if hasattr(model, 'state_dict'):
            self.deconvnet_model = build_deconvnet_model(model)
        else:
            self.deconvnet_model = model
        self._hooks = []  # For compatibility with tests
        self._deconvnet_fn = deconvnet

    def attribute(self, inputs, target=None):
        """Calculate attribution using DeconvNet.

        Args:
            inputs: Input tensor
            target: Target class index (None for argmax)

        Returns:
            Attribution tensor of the same shape as inputs
        """
        attribution = self._deconvnet_fn(self.deconvnet_model, inputs, target_class=target)
        return attribution