Source code for signxai.api

"""
Unified API for SignXAI - Cross-framework XAI explanations.

This module provides a unified interface for generating explanations
across TensorFlow and PyTorch frameworks with automatic parameter mapping
and framework detection.
"""

import numpy as np
from typing import Union, Optional, Any, Dict


def _normalize_framework(framework) -> str:
    """Return *framework* lower-cased after checking it names a supported backend.

    Args:
        framework: Candidate framework name (any case).

    Returns:
        'tensorflow' or 'pytorch'.

    Raises:
        ValueError: If *framework* is not a string naming 'tensorflow' or
            'pytorch' (case-insensitive).
    """
    if not isinstance(framework, str):
        raise ValueError("Framework must be 'tensorflow' or 'pytorch'")
    normalized = framework.lower()
    if normalized not in ('tensorflow', 'pytorch'):
        raise ValueError("Framework must be 'tensorflow' or 'pytorch'")
    return normalized


def explain(
    model,
    x: Union[np.ndarray, "torch.Tensor", "tf.Tensor"],
    method_name: str,
    target_class: Optional[int] = None,
    framework: Optional[str] = None,
    **kwargs
) -> np.ndarray:
    """
    Generate explanations for model predictions using various XAI methods.

    This unified API automatically handles framework detection, model
    preparation, and parameter mapping to provide consistent explanations
    across TensorFlow and PyTorch implementations.

    Args:
        model: The model to explain (TensorFlow Keras model or PyTorch nn.Module).
        x: Input data as numpy array or framework tensor.
        method_name: Name of the XAI method to apply. Supported methods include:
            - Gradient-based: 'gradient', 'smoothgrad', 'integrated_gradients', 'vargrad'
            - Backprop methods: 'guided_backprop', 'deconvnet'
            - Feature methods: 'gradcam'
            - LRP methods: 'lrp_epsilon', 'lrp_alpha_1_beta_0', 'lrp_alpha_2_beta_1'
            - And many more (see documentation / ``list_methods()`` for the full list)
        target_class: Target class index for explanation. If None, the
            predicted class is used.
        framework: Framework to use ('tensorflow' or 'pytorch', case-insensitive).
            If None, it is auto-detected from ``model``.
        **kwargs: Method-specific parameters, mapped between frameworks before
            being forwarded. Common parameters:
            - steps: Number of steps for Integrated Gradients (default: 50)
            - num_samples: Number of samples for SmoothGrad (default: 25)
            - noise_level: Noise level for SmoothGrad (default: 0.1)
            - layer_name: Target layer for Grad-CAM (framework-specific)
            - epsilon: Epsilon value for LRP methods (default: 0.1)
            - alpha, beta: Alpha-beta values for LRP alpha-beta methods

    Returns:
        Explanation/relevance map as numpy array with same spatial dimensions as input.

    Raises:
        ValueError: If the framework cannot be detected, or is not a string
            naming 'tensorflow' or 'pytorch'.
        ImportError: If required framework dependencies are not installed.

    Warns:
        RuntimeWarning: If the softmax could not be removed from the model;
            the unmodified model is used instead.

    Examples:
        Basic gradient explanation:
        >>> explanation = explain(model, image, 'gradient')

        Integrated Gradients with custom steps:
        >>> explanation = explain(model, image, 'integrated_gradients', steps=100)

        Grad-CAM on specific layer:
        >>> explanation = explain(model, image, 'gradcam', layer_name='block5_conv3')

        LRP with epsilon rule:
        >>> explanation = explain(model, image, 'lrp_epsilon', epsilon=0.1)

        Cross-framework usage (same API for both):
        >>> tf_explanation = explain(tf_model, data, 'smoothgrad', framework='tensorflow')
        >>> pt_explanation = explain(pt_model, data, 'smoothgrad', framework='pytorch')
    """
    import warnings

    # Validate an explicitly requested framework first, before any backend
    # machinery is imported, so a typo fails fast with a clear ValueError.
    if framework is not None:
        framework = _normalize_framework(framework)

    # Import here to avoid circular imports
    from . import _detect_framework, _prepare_model, _prepare_input
    from . import _get_predicted_class, _map_parameters
    from . import _call_tensorflow_method, _call_pytorch_method
    from . import _load_tf_signxai, _load_torch_signxai

    # Framework detection if not specified
    if framework is None:
        detected = _detect_framework(model)
        if detected is None:
            raise ValueError(
                "Could not detect framework. "
                "Please specify framework='tensorflow' or framework='pytorch'"
            )
        framework = _normalize_framework(detected)

    # Ensure the selected backend is importable before doing any work with it.
    if framework == 'tensorflow':
        if _load_tf_signxai() is None:
            raise ImportError("TensorFlow not available. Install with: pip install signxai[tensorflow]")
    elif _load_torch_signxai() is None:
        raise ImportError("PyTorch not available. Install with: pip install signxai[pytorch]")

    # Prepare model (ensure no softmax for explanations). This is best-effort:
    # on any failure, fall back to the unmodified model and warn the caller
    # through the warnings machinery rather than printing to stdout.
    try:
        prepared_model = _prepare_model(model, framework)
    except Exception as e:
        warnings.warn(
            f"Could not remove softmax from model: {e}",
            RuntimeWarning,
            stacklevel=2,
        )
        prepared_model = model

    # Prepare input data
    prepared_input = _prepare_input(x, framework)

    # Handle target class: default to the model's predicted class.
    if target_class is None:
        target_class = _get_predicted_class(prepared_model, prepared_input, framework)

    # Map common parameters between frameworks
    mapped_kwargs = _map_parameters(method_name, framework, **kwargs)

    # Call framework-specific implementation
    if framework == 'tensorflow':
        return _call_tensorflow_method(
            prepared_model, prepared_input, method_name, target_class, **mapped_kwargs
        )
    return _call_pytorch_method(
        prepared_model, prepared_input, method_name, target_class, **mapped_kwargs
    )
def list_methods(framework: Optional[str] = None) -> Dict[str, list]:
    """
    Enumerate the public explanation functions each backend exposes.

    Args:
        framework: 'tensorflow' or 'pytorch' (case-insensitive), or None to
            query both. Any other string selects neither backend, so the
            result is an empty dict.

    Returns:
        Mapping from backend name to the sorted list of function names found
        in that backend's ``methods.wrappers`` module. A backend whose loader
        returns None is omitted. If importing or scanning its wrappers raises,
        the value for that backend is an error-message string instead of a list.
    """
    import inspect
    from . import _load_tf_signxai, _load_torch_signxai

    def requested(backend: str) -> bool:
        # None means "all backends"; otherwise compare case-insensitively.
        return framework is None or framework.lower() == backend

    catalogue = {}

    if requested('tensorflow') and _load_tf_signxai() is not None:
        try:
            import signxai.tf_signxai.methods.wrappers as tf_wrappers
            # Skip private helpers and the native-calculation entry points.
            catalogue['tensorflow'] = sorted(
                fn_name
                for fn_name, _fn in inspect.getmembers(tf_wrappers, inspect.isfunction)
                if not fn_name.startswith(('_', 'calculate_native'))
            )
        except Exception as exc:
            catalogue['tensorflow'] = f"Error loading TensorFlow methods: {exc}"

    if requested('pytorch') and _load_torch_signxai() is not None:
        # Generic dispatchers, not individual XAI methods.
        dispatchers = ('calculate_relevancemap', 'calculate_relevancemaps')
        try:
            import signxai.torch_signxai.methods.wrappers as pt_wrappers
            catalogue['pytorch'] = sorted(
                fn_name
                for fn_name, _fn in inspect.getmembers(pt_wrappers, inspect.isfunction)
                if not fn_name.startswith('_') and fn_name not in dispatchers
            )
        except Exception as exc:
            catalogue['pytorch'] = f"Error loading PyTorch methods: {exc}"

    return catalogue
def get_method_info(method_name: str, framework: Optional[str] = None) -> Dict[str, Any]:
    """
    Get detailed information about a specific XAI method.

    Args:
        method_name: Name of the method to get info for.
        framework: Framework to check ('tensorflow' or 'pytorch',
            case-insensitive), or None to check both.

    Returns:
        Dictionary with:
          - 'method_name': the queried name.
          - 'available_in': backends whose wrappers module defines a callable
            named ``method_name`` whose signature could be introspected.
          - '<backend>': ``{'signature': str, 'docstring': str}`` for each
            backend listed in 'available_in'.
          - '<backend>_error': error message if importing the backend's
            wrappers or introspecting the method raised.
    """
    import inspect
    from . import _load_tf_signxai, _load_torch_signxai

    info = {'method_name': method_name, 'available_in': []}

    def wanted(backend: str) -> bool:
        # None means "all backends"; otherwise compare case-insensitively.
        return framework is None or framework.lower() == backend

    def record(backend: str, wrappers_module) -> None:
        """Add *backend*'s details for ``method_name`` to ``info`` if it is callable there."""
        func = getattr(wrappers_module, method_name, None)
        # Non-callable attributes (constants, submodules, missing names) are not methods.
        if not callable(func):
            return
        # Introspect BEFORE recording availability, so a failure here cannot
        # leave the backend listed in 'available_in' without its details entry.
        details = {
            'signature': str(inspect.signature(func)),
            'docstring': inspect.getdoc(func) or "No documentation available",
        }
        info['available_in'].append(backend)
        info[backend] = details

    # Check TensorFlow
    if wanted('tensorflow') and _load_tf_signxai() is not None:
        try:
            import signxai.tf_signxai.methods.wrappers as tf_wrappers
            record('tensorflow', tf_wrappers)
        except Exception as e:
            info['tensorflow_error'] = str(e)

    # Check PyTorch
    if wanted('pytorch') and _load_torch_signxai() is not None:
        try:
            import signxai.torch_signxai.methods.wrappers as pt_wrappers
            record('pytorch', pt_wrappers)
        except Exception as e:
            info['pytorch_error'] = str(e)

    return info
# Common method parameter presets for easy use.
# Maps a method name to the baseline keyword arguments that
# explain_with_preset() starts from before applying preset adjustments and
# caller overrides. Methods not listed here start from an empty dict.
METHOD_PRESETS = {
    # Gradient-based methods
    'gradient': {},
    'smoothgrad': {'num_samples': 25, 'noise_level': 0.1},
    'integrated_gradients': {'steps': 50},
    'vargrad': {'num_samples': 25, 'noise_level': 0.2},
    # Backprop methods
    'guided_backprop': {},
    'deconvnet': {},
    # Feature methods
    'gradcam': {},  # Requires layer_name (supply it as a keyword argument)
    # LRP methods
    'lrp_epsilon': {'epsilon': 0.1},
    'lrp_alpha_1_beta_0': {'alpha': 1.0, 'beta': 0.0},
    'lrp_alpha_2_beta_1': {'alpha': 2.0, 'beta': 1.0},
}
# Per-preset parameter adjustments applied on top of METHOD_PRESETS, keyed by
# preset name and then by method name. 'default' adds nothing; methods absent
# from a preset's table keep their baseline parameters.
_PRESET_ADJUSTMENTS = {
    'default': {},
    'fast': {
        'smoothgrad': {'num_samples': 10, 'noise_level': 0.15},
        'integrated_gradients': {'steps': 20},
    },
    'high_quality': {
        'smoothgrad': {'num_samples': 50, 'noise_level': 0.05},
        'integrated_gradients': {'steps': 100},
    },
}


def explain_with_preset(model, x, method_name: str, preset: str = 'default', **override_kwargs) -> np.ndarray:
    """
    Explain using predefined parameter presets for common use cases.

    Parameters are layered in this order, later layers winning:
    the method's baseline from ``METHOD_PRESETS``, then the preset's
    adjustment for that method (only 'smoothgrad' and
    'integrated_gradients' are adjusted), then ``override_kwargs``.

    Args:
        model: Model to explain.
        x: Input data.
        method_name: XAI method name.
        preset: Preset name ('default', 'fast', 'high_quality').
        **override_kwargs: Parameters to override preset values. These are
            forwarded to ``explain`` unchanged, so ``target_class`` and
            ``framework`` may also be passed here.

    Returns:
        Explanation as numpy array.

    Raises:
        ValueError: If ``preset`` is not one of the known preset names.
    """
    # Reject unknown presets explicitly instead of silently falling back to
    # the defaults, which would hide typos such as 'highquality'.
    if not isinstance(preset, str) or preset not in _PRESET_ADJUSTMENTS:
        raise ValueError(
            f"Unknown preset {preset!r}; expected one of {sorted(_PRESET_ADJUSTMENTS)}"
        )

    # Copy so the shared METHOD_PRESETS table is never mutated.
    params = dict(METHOD_PRESETS.get(method_name, {}))
    params.update(_PRESET_ADJUSTMENTS[preset].get(method_name, {}))
    params.update(override_kwargs)
    return explain(model, x, method_name, **params)