# Source code for signxai.tf_signxai.methods

# signxai/tf_signxai/methods.py
"""
Refactored TensorFlow explanation methods with a unified execution entry point.
"""
import numpy as np
import tensorflow as tf
from typing import Dict, Any, List

from signxai.tf_signxai.methods_impl.grad_cam import calculate_grad_cam_relevancemap, calculate_grad_cam_relevancemap_timeseries
from signxai.tf_signxai.methods_impl.guided_backprop import guided_backprop_on_guided_model
from signxai.tf_signxai.methods_impl.signed import calculate_sign_mu
from signxai.tf_signxai.tf_utils import calculate_explanation_innvestigate
from signxai.common.method_parser import MethodParser
from signxai.common.method_normalizer import MethodNormalizer

# A registry to map base method names to their core implementation functions.
METHOD_IMPLEMENTATIONS = {}


def register_method(name):
    """Return a decorator that records a function in ``METHOD_IMPLEMENTATIONS``.

    The decorated function is stored under the key ``name`` (replacing any
    entry already stored under that key) and is handed back unchanged, so the
    decorator can be stacked or the function called directly.
    """
    def _record(implementation):
        # Registration is the only side effect; the callable itself is untouched.
        METHOD_IMPLEMENTATIONS[name] = implementation
        return implementation

    return _record
# --- Core Method Implementations ---

@register_method("gradient")
def _gradient(model, x, **kwargs):
    """Delegate to ``calculate_explanation_innvestigate`` with ``method='gradient'``."""
    return calculate_explanation_innvestigate(model, x, method='gradient', **kwargs)


@register_method("smoothgrad")
def _smoothgrad(model, x, **kwargs):
    """Run the 'smoothgrad' method, layering caller kwargs over the presets.

    Values in ``kwargs`` override ``MethodNormalizer.METHOD_PRESETS['smoothgrad']``.
    A missing preset entry is treated as "no presets", matching how
    ``execute`` looks presets up.
    """
    params = {**MethodNormalizer.METHOD_PRESETS.get('smoothgrad', {}), **kwargs}
    return calculate_explanation_innvestigate(model, x, method='smoothgrad', **params)


@register_method("integrated_gradients")
def _integrated_gradients(model, x, **kwargs):
    """Run the 'integrated_gradients' method, layering caller kwargs over the presets.

    Values in ``kwargs`` override
    ``MethodNormalizer.METHOD_PRESETS['integrated_gradients']``. A missing
    preset entry is treated as "no presets", matching how ``execute`` looks
    presets up.
    """
    params = {**MethodNormalizer.METHOD_PRESETS.get('integrated_gradients', {}), **kwargs}
    return calculate_explanation_innvestigate(model, x, method='integrated_gradients', **params)


@register_method("guided_backprop")
def _guided_backprop(model, x, **kwargs):
    """Delegate to ``calculate_explanation_innvestigate`` with ``method='guided_backprop'``."""
    return calculate_explanation_innvestigate(model, x, method='guided_backprop', **kwargs)


@register_method("deconvnet")
def _deconvnet(model, x, **kwargs):
    """Delegate to ``calculate_explanation_innvestigate`` with ``method='deconvnet'``."""
    return calculate_explanation_innvestigate(model, x, method='deconvnet', **kwargs)


@register_method("grad_cam")
def _grad_cam(model, x, **kwargs):
    """Dispatch to the time-series or image Grad-CAM helper based on ``x.ndim``.

    Inputs with at most three dimensions are assumed to be time series;
    anything with more dimensions goes to the image helper. Note that both
    helpers take ``(x, model)`` in that order.
    """
    if x.ndim <= 3:  # Assuming timeseries
        return calculate_grad_cam_relevancemap_timeseries(x, model, **kwargs)
    return calculate_grad_cam_relevancemap(x, model, **kwargs)


# Rule names that themselves contain underscores. The names are taken from
# iNNvestigate's analyzer registry ("lrp.<rule>") -- NOTE(review): confirm
# against the installed iNNvestigate version. Sorted longest-first so that a
# short rule never shadows a longer one sharing its prefix.
_COMPOUND_LRP_RULES = tuple(sorted(
    (
        'sequential_preset_b_flat_until_idx',
        'sequential_preset_a_flat',
        'sequential_preset_b_flat',
        'sequential_preset_a',
        'sequential_preset_b',
        'alpha_1_beta_0_IB',
        'alpha_2_beta_1_IB',
        'alpha_1_beta_0',
        'alpha_2_beta_1',
        'alpha_beta',
        'z_plus_fast',
        'z_plus',
        'w_square',
        'epsilon_IB',
        'z_IB',
    ),
    key=len,
    reverse=True,
))


def _resolve_lrp_rule(rule):
    """Reduce a full method name such as 'lrp_epsilon_50_x_sign' to its rule.

    Names that do not start with 'lrp', or that contain no underscore, are
    returned unchanged. Otherwise everything up to the first underscore is
    dropped and the rule is the longest known compound rule that prefixes the
    remainder, falling back to the first underscore-separated token.
    """
    if not rule.startswith('lrp'):
        return rule
    tokens = rule.split('_')
    if len(tokens) <= 1:
        return rule
    remainder = '_'.join(tokens[1:])
    for candidate in _COMPOUND_LRP_RULES:
        if remainder == candidate or remainder.startswith(candidate + '_'):
            return candidate
    # Single-word rule, e.g. 'epsilon', 'z', 'flat'.
    return tokens[1]


@register_method("lrp")
def _lrp(model, x, **kwargs):
    """
    Unified LRP implementation for TensorFlow using iNNvestigate.

    The ``rule`` keyword (default ``'epsilon'``) selects the LRP variant. It is
    removed from ``kwargs`` so it is not forwarded, and it may be given either
    as a bare rule name or as a full method name starting with 'lrp', from
    which the rule is extracted. Compound rules such as 'alpha_1_beta_0' or
    'z_plus' are kept whole instead of being cut at their first underscore.
    """
    # Extract rule parameter and remove it from kwargs to avoid passing it to iNNvestigate
    rule = _resolve_lrp_rule(kwargs.pop('rule', 'epsilon'))

    # iNNvestigate uses dot notation for LRP methods
    method_name = f"lrp.{rule}"
    return calculate_explanation_innvestigate(model, x, method=method_name, **kwargs)


# --- Modifier Application ---

def _apply_modifiers(relevance_map: np.ndarray, x: np.ndarray, modifiers: List[str], params: Dict[str, Any]) -> np.ndarray:
    """
    Applies a chain of modifiers to a relevance map.

    Args:
        relevance_map: Map to modify. It is copied first, so the caller's
            array is never mutated; with no modifiers it is returned as-is.
        x: Input the map was computed for; it is multiplied into the map, so
            it must broadcast against the map.
        modifiers: Modifier names applied in order. 'x_input'/'input' multiply
            by ``x``; 'x_sign'/'sign' multiply by the sign of ``x``; 'std_x' is
            not implemented and leaves the map unchanged. Any other name is
            ignored.
        params: Method parameters. If it contains 'mu', the sign modifier uses
            ``calculate_sign_mu(x, mu)``; otherwise it uses ``x / |x|`` with
            undefined ratios (e.g. at ``x == 0``) mapped to ``+1``.

    Returns:
        The modified copy of the relevance map.
    """
    if not modifiers:
        return relevance_map

    # Function-scope import: the module does not import logging at top level.
    import logging
    logger = logging.getLogger(__name__)

    modified_map = relevance_map.copy()
    for modifier in modifiers:
        if modifier in ('x_input', 'input'):
            modified_map *= x
        elif modifier in ('x_sign', 'sign'):
            if 'mu' in params:
                mu = params['mu']
                logger.debug("Applying sign with mu=%s", mu)
                modified_map *= calculate_sign_mu(x, mu)
            else:
                # 0/0 is expected here and is mapped to +1 below, so silence
                # the RuntimeWarning NumPy would otherwise emit.
                with np.errstate(divide='ignore', invalid='ignore'):
                    ratio = x / np.abs(x)
                modified_map *= np.nan_to_num(ratio, nan=1.0)
        elif modifier == 'std_x':
            # Standard deviation normalization is not implemented; make the
            # no-op visible instead of silently returning an unmodified map.
            logger.warning("Modifier 'std_x' is not implemented; relevance map left unchanged.")
    return modified_map


# --- Main Execution Function ---
def execute(model, x, parsed_method: Dict[str, Any], **kwargs) -> np.ndarray:
    """
    Executes the specified XAI method for TensorFlow.

    Args:
        model: Model handed through to the registered implementation.
        x: Input sample. A 3-D input whose last axis has at most 4 entries is
            treated as a single image (H, W, C) and a 2-D input as a single
            time series (T, F); both get a leading batch axis before the
            method runs. Any other input is passed through unchanged.
        parsed_method: Parsed method description. Reads 'base_method',
            'params' and 'modifiers', plus 'original_name' when falling back
            to the generic LRP implementation.
        **kwargs: Extra parameters; they override both the presets and
            ``parsed_method['params']``.

    Returns:
        The relevance map after all modifiers have been applied. If a batch
        axis was added to ``x`` and the method returned a single-sample result
        of the same rank as the batched input, that axis is removed again.

    Raises:
        ValueError: If the normalized method is neither registered nor an
            'lrp*' variant.
    """
    base_method = MethodNormalizer.normalize(parsed_method['base_method'], 'tensorflow')

    # Precedence (lowest to highest): presets, parsed parameters, caller kwargs.
    all_params = {
        **MethodNormalizer.METHOD_PRESETS.get(base_method, {}),
        **parsed_method['params'],
        **kwargs,
    }

    if base_method not in METHOD_IMPLEMENTATIONS:
        if not base_method.startswith('lrp'):
            raise ValueError(f"Method '{base_method}' is not implemented for TensorFlow.")
        # Unregistered LRP variants are routed to the generic 'lrp'
        # implementation, with the rule taken from the original method name.
        base_method = 'lrp'
        all_params['rule'] = parsed_method['original_name'].split('.')[-1]

    core_method_func = METHOD_IMPLEMENTATIONS[base_method]

    # Prepare input - add batch dimension if necessary
    # For images: (H, W, C) needs batch -> (1, H, W, C)
    # For time series: (T, F) needs batch -> (1, T, F)
    needs_batch_dim = (x.ndim == 3 and x.shape[-1] <= 4) or x.ndim == 2
    x_input = np.expand_dims(x, axis=0) if needs_batch_dim else x

    relevance_map_np = core_method_func(model, x_input, **all_params)

    # Remove the batch dimension if it was added. Comparing against the
    # batched input's rank covers images (rank 4) and time series (rank 3)
    # alike; results of a different rank are left untouched.
    if (
        needs_batch_dim
        and relevance_map_np.ndim == x_input.ndim
        and relevance_map_np.shape[0] == 1
    ):
        relevance_map_np = relevance_map_np[0]

    return _apply_modifiers(relevance_map_np, x, parsed_method['modifiers'], all_params)