# Source code for signxai.common.method_families

"""
Method Family Architecture for SignXAI2

This module implements a family-based approach for XAI methods, grouping
genuinely similar methods while preserving complex method-specific logic.
"""

import os
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, Callable
import logging

logger = logging.getLogger(__name__)


class MethodFamily(ABC):
    """Base class for all method families.

    A family groups XAI methods that share an implementation strategy.
    Subclasses decide which method names they accept (``can_handle``) and
    provide one implementation per supported framework.
    """

    def __init__(self):
        # Both containers start empty; subclasses fill in what they need.
        self.framework_handlers = {}
        self.supported_methods = set()

    @abstractmethod
    def can_handle(self, method_name: str) -> bool:
        """Check if this family can handle the given method."""

    @abstractmethod
    def execute_tensorflow(self, model, x, method_name: str, **kwargs):
        """Execute method for TensorFlow."""

    @abstractmethod
    def execute_pytorch(self, model, x, method_name: str, **kwargs):
        """Execute method for PyTorch."""

    def execute(self, model, x, method_name: str, framework: str, **kwargs):
        """Main execution entry point.

        Args:
            model: Model object, forwarded untouched to the framework hook.
            x: Input, forwarded untouched to the framework hook.
            method_name: Name of the method to run.
            framework: Either ``'tensorflow'`` or ``'pytorch'``.
            **kwargs: Forwarded untouched to the framework hook.

        Returns:
            Whatever the selected framework hook returns.

        Raises:
            ValueError: If ``framework`` is neither of the two known values.
        """
        if framework == 'tensorflow':
            runner = self.execute_tensorflow
        elif framework == 'pytorch':
            runner = self.execute_pytorch
        else:
            raise ValueError(f"Unsupported framework: {framework}")
        return runner(model, x, method_name, **kwargs)
class SimpleGradientFamily(MethodFamily):
    """
    Handles basic gradient-based methods that are truly similar.

    Safe for consolidation with minimal risk.

    Every method here starts from a plain gradient attribution and then
    rescales it with quantities derived from the input ``x`` (the input
    itself, its sign, or its standard deviation).
    """

    def __init__(self):
        super().__init__()
        self.supported_methods = {
            'gradient',
            'gradient_x_input',
            'gradient_x_sign',
            'gradient_x_input_x_sign',
            'gradient_x_sign_mu',
            'input_t_gradient',  # This is gradient x input
        }

    def can_handle(self, method_name: str) -> bool:
        """Check if this is a simple gradient method.

        A name is accepted when its first ``_``-separated token is
        ``gradient``, when it is ``input_t_gradient``, or when it is listed
        in ``supported_methods``. Matching is case-insensitive.
        """
        method_lower = method_name.lower()
        base = method_lower.split('_')[0]
        return (
            base == 'gradient'
            or method_lower == 'input_t_gradient'
            or method_lower in self.supported_methods
        )

    def execute_tensorflow(self, model, x, method_name: str, **kwargs):
        """Execute gradient methods for TensorFlow with dynamic modifiers.

        Args:
            model: Model handed to ``calculate_explanation_innvestigate``.
            x: Input; also used as the multiplier for the modifiers, so it
                must support ``*`` with the result and ``np.sign``/``np.std``.
            method_name: Method name; modifiers are detected from substrings
                (``_x_input``, ``_x_sign``, ``std_x``).
            **kwargs: ``target_class``/``neuron_selection`` select the neuron;
                ``modifier`` may name modifiers (``input``, ``sign``, ``std``);
                ``mu`` switches the sign modifier to ``calculate_sign_mu``.
                All other entries are forwarded to the explanation call.

        Returns:
            The modified attribution.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..tf_signxai.tf_utils import calculate_explanation_innvestigate
            import numpy as np

            method_lower = method_name.lower()

            # The base attribution is always the plain gradient: every name
            # routed here (including 'input_t_gradient') is derived from it
            # by the modifiers applied below.
            result = calculate_explanation_innvestigate(
                model, x,
                method='gradient',
                neuron_selection=kwargs.get('target_class', kwargs.get('neuron_selection')),
                **{k: v for k, v in kwargs.items()
                   if k not in ('target_class', 'neuron_selection', 'modifier')}
            )

            # Modifiers can come from the method name OR from kwargs.
            modifiers = kwargs.get('modifier', '')

            # 'input_t_gradient' is gradient * input, same as '_x_input'.
            if (method_lower == 'input_t_gradient'
                    or '_x_input' in method_lower
                    or 'input' in modifiers):
                result = result * x

            if '_x_sign' in method_lower or 'sign' in modifiers:
                if 'mu' in kwargs:
                    # mu is supplied by the caller (parsed by MethodParser).
                    mu = kwargs['mu']
                    from ..tf_signxai.methods_impl.signed import calculate_sign_mu
                    result = result * calculate_sign_mu(x, mu)
                else:
                    # Simple sign
                    result = result * np.sign(x)

            if 'std_x' in method_lower or 'std' in modifiers:
                std = np.std(x)
                if std > 0:  # avoid dividing by zero for constant inputs
                    result = result / std

            return result

        except Exception as e:
            logger.warning(f"SimpleGradientFamily failed for {method_name}: {e}")
            raise

    def execute_pytorch(self, model, x, method_name: str, **kwargs):
        """Execute gradient methods for PyTorch.

        Args:
            model: Model handed to ``GradientAnalyzer``.
            x: Input as a ``torch.Tensor`` or NumPy array.
            method_name: Method name; any name containing ``input`` is
                multiplied by the input, any name containing ``sign`` by the
                input's sign (or ``calculate_sign_mu`` when ``mu`` is given).
            **kwargs: ``target_class`` is passed to the analyzer; ``mu``
                selects the thresholded sign.

        Returns:
            The attribution as a NumPy array.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            import torch
            from ..torch_signxai.methods_impl.zennit_impl.analyzers import GradientAnalyzer

            analyzer = GradientAnalyzer(model)
            gradient = analyzer.analyze(x, kwargs.get('target_class'))

            # Work on tensors so the modifiers below are uniform.
            gradient_tensor = gradient if isinstance(gradient, torch.Tensor) else torch.from_numpy(gradient)
            x_tensor = x if isinstance(x, torch.Tensor) else torch.from_numpy(x)

            method_lower = method_name.lower()

            # Covers both '*_x_input' and 'input_t_gradient'.
            if 'input' in method_lower:
                gradient_tensor = gradient_tensor * x_tensor

            if 'sign' in method_lower:
                if 'mu' in kwargs:
                    # mu is supplied by the caller (parsed by MethodParser).
                    mu = kwargs['mu']
                    from ..torch_signxai.methods_impl.signed import calculate_sign_mu
                    sign_mu = calculate_sign_mu(x_tensor.detach().cpu().numpy(), mu)
                    gradient_tensor = gradient_tensor * torch.from_numpy(sign_mu)
                else:
                    gradient_tensor = gradient_tensor * torch.sign(x_tensor)

            return gradient_tensor.detach().cpu().numpy()

        except Exception as e:
            logger.warning(f"SimpleGradientFamily failed for {method_name}: {e}")
            raise
class StochasticMethodFamily(MethodFamily):
    """
    Handles noise-based attribution methods.

    Moderate risk - requires careful parameter handling.

    Covers SmoothGrad, VarGrad and Integrated Gradients plus their
    ``_x_input`` / ``_x_sign`` / ``std_x`` variants.
    """

    def __init__(self):
        super().__init__()
        self.supported_methods = {
            'smoothgrad', 'vargrad', 'integrated_gradients',
            'smoothgrad_x_input', 'smoothgrad_x_sign',
            'vargrad_x_input', 'vargrad_x_sign',
        }

    def can_handle(self, method_name: str) -> bool:
        """Check if this is a stochastic method.

        Accepts names starting with ``smoothgrad`` or ``vargrad`` and any
        name containing ``integrated`` (case-insensitive).
        """
        method_lower = method_name.lower()
        if method_lower.startswith(('smoothgrad', 'vargrad')):
            return True
        return 'integrated' in method_lower

    def execute_tensorflow(self, model, x, method_name: str, **kwargs):
        """Execute stochastic methods for TensorFlow with dynamic modifiers.

        Args:
            model: Model handed to ``calculate_explanation_innvestigate``.
            x: Input; also used for the input/sign/std modifiers.
            method_name: Method name; the base method and modifiers are
                detected from substrings.
            **kwargs: ``num_samples``/``noise_level`` are accepted as aliases
                for ``augment_by_n``/``noise_scale``, and ``ig_steps`` for
                ``steps``. ``target_class``/``neuron_selection`` select the
                neuron, ``modifier`` may name modifiers, ``mu`` selects the
                thresholded sign. Remaining entries are forwarded.

        Returns:
            The modified attribution.

        Raises:
            ValueError: If the name matches no known base method.
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..tf_signxai.tf_utils import calculate_explanation_innvestigate
            import numpy as np

            method_lower = method_name.lower()

            # Determine base method and fill in its default parameters.
            if 'smoothgrad' in method_lower:
                method_for_innvestigate = 'smoothgrad'
            elif 'vargrad' in method_lower:
                method_for_innvestigate = 'vargrad'
            elif 'integrated' in method_lower:
                method_for_innvestigate = 'integrated_gradients'
            else:
                raise ValueError(f"Unknown stochastic method: {method_lower}")

            if method_for_innvestigate == 'integrated_gradients':
                kwargs['steps'] = kwargs.get('steps', kwargs.get('ig_steps', 64))
            else:
                kwargs['augment_by_n'] = kwargs.get('augment_by_n', kwargs.get('num_samples', 50))
                kwargs['noise_scale'] = kwargs.get('noise_scale', kwargs.get('noise_level', 0.1))

            neuron_selection = kwargs.get('target_class', kwargs.get('neuron_selection'))
            passthrough = {
                k: v for k, v in kwargs.items()
                if k not in ('target_class', 'neuron_selection', 'modifier')
            }

            try:
                result = calculate_explanation_innvestigate(
                    model, x,
                    method=method_for_innvestigate,
                    neuron_selection=neuron_selection,
                    **passthrough
                )
            except Exception as e:
                # Only integrated-gradients names get a second attempt with
                # an explicit batch dimension; everything else fails as-is.
                if 'integrated' not in method_lower:
                    raise

                if x.ndim == 3:  # (H, W, C) -> (1, H, W, C)
                    x_batched = np.expand_dims(x, axis=0)
                elif x.ndim == 2:  # (H, W) -> (1, H, W, 1)
                    x_batched = np.expand_dims(np.expand_dims(x, axis=0), axis=-1)
                else:
                    x_batched = x

                try:
                    result = calculate_explanation_innvestigate(
                        model, x_batched,
                        method=method_for_innvestigate,
                        neuron_selection=neuron_selection,
                        **passthrough
                    )
                    # Remove a leading singleton batch dimension again.
                    if result.ndim in (3, 4) and result.shape[0] == 1:
                        result = result[0]
                except Exception:
                    # The retry failed too: report the ORIGINAL error.
                    # (Not a bare except, so Ctrl-C / SystemExit propagate.)
                    raise e

            # Apply modifiers dynamically - from method name OR kwargs.
            modifiers = kwargs.get('modifier', '')

            if '_x_input' in method_lower or 'input' in modifiers:
                result = result * x

            if '_x_sign' in method_lower or 'sign' in modifiers:
                if 'mu' in kwargs:
                    # mu is supplied by the caller (parsed by MethodParser).
                    mu = kwargs['mu']
                    from ..tf_signxai.methods_impl.signed import calculate_sign_mu
                    result = result * calculate_sign_mu(x, mu)
                else:
                    result = result * np.sign(x)

            if 'std_x' in method_lower or 'std' in modifiers:
                std = np.std(x)
                if std > 0:  # avoid dividing by zero for constant inputs
                    result = result / std

            return result

        except Exception as e:
            logger.warning(f"StochasticMethodFamily failed for {method_name}: {e}")
            raise

    def execute_pytorch(self, model, x, method_name: str, **kwargs):
        """Execute stochastic methods for PyTorch.

        The base method is the first ``_``-separated token of the name.

        Args:
            model: Model handed to the analyzer.
            x: Input as a ``torch.Tensor`` or NumPy array.
            method_name: Method name; names containing ``input``/``sign``
                additionally get the input / sign modifier.
            **kwargs: ``noise_level``/``noise_scale``, ``num_samples``/
                ``augment_by_n``, ``ig_steps``/``steps``, ``baseline``/
                ``reference_inputs``, ``target_class`` and ``mu``.

        Returns:
            The attribution as a NumPy array.

        Raises:
            ValueError: If the base token is not a known stochastic method.
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            import torch
            import numpy as np

            method_lower = method_name.lower()
            base_method = method_lower.split('_')[0]

            if base_method == 'smoothgrad':
                from ..torch_signxai.methods_impl.zennit_impl.analyzers import SmoothGradAnalyzer
                noise_level = kwargs.get('noise_level', kwargs.get('noise_scale', 0.1))
                num_samples = kwargs.get('num_samples', kwargs.get('augment_by_n', 50))
                analyzer = SmoothGradAnalyzer(model, noise_level, num_samples)

            elif base_method == 'vargrad':
                from ..torch_signxai.methods_impl.zennit_impl.analyzers import VarGradAnalyzer
                noise_level = kwargs.get('noise_level', kwargs.get('noise_scale', 0.1))
                num_samples = kwargs.get('num_samples', kwargs.get('augment_by_n', 50))
                analyzer = VarGradAnalyzer(model, noise_level, num_samples)

            elif base_method in ('integrated', 'integratedgradients'):
                from ..torch_signxai.methods_impl.zennit_impl.analyzers import IntegratedGradientsAnalyzer
                steps = kwargs.get('ig_steps', kwargs.get('steps', 64))
                baseline = kwargs.get('baseline', kwargs.get('reference_inputs'))
                analyzer = IntegratedGradientsAnalyzer(model, steps, baseline)

            else:
                raise ValueError(f"Unknown stochastic method: {base_method}")

            result = analyzer.analyze(x, kwargs.get('target_class'))

            # Ensure numpy array before applying NumPy-based modifiers.
            if isinstance(result, torch.Tensor):
                result = result.detach().cpu().numpy()

            wants_input = 'input' in method_lower
            wants_sign = 'sign' in method_lower
            if wants_input or wants_sign:
                x_np = x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x

                if wants_input:
                    result = result * x_np

                if wants_sign:
                    if 'mu' in kwargs:
                        # mu is supplied by the caller (parsed by MethodParser).
                        mu = kwargs['mu']
                        from ..torch_signxai.methods_impl.signed import calculate_sign_mu
                        result = result * calculate_sign_mu(x_np, mu)
                    else:
                        result = result * np.sign(x_np)

            return result

        except Exception as e:
            logger.warning(f"StochasticMethodFamily failed for {method_name}: {e}")
            raise
class LRPBasicFamily(MethodFamily):
    """
    Handles basic LRP methods with simple epsilon/alpha-beta rules.

    Moderate risk - requires careful rule handling.

    Method names encode the rule and its parameters with underscores, where
    an underscore inside a number stands for the decimal point
    (``lrp_epsilon_0_1`` -> epsilon 0.1). A ``std_x`` marker in the name
    makes the result get divided by ``np.std(x)``.
    """

    def __init__(self):
        super().__init__()
        self.supported_methods = {
            'lrp',  # Basic LRP without any suffix
            'lrp_epsilon',
            'lrp_alpha_beta',
            'lrp_z',
            'lrp_flat',
            'lrp_w_square',
            'lrp_zplus',
            'lrp_gamma',
            # Common alpha-beta presets
            'lrp_alpha_1_beta_0',
            'lrp_alpha_2_beta_1',
        }

    @staticmethod
    def _decode_number(text):
        """Decode an underscore-encoded number, or return ``None``.

        Any ``std_x`` marker is removed, underscores become dots, and one
        trailing dot (left over from a removed marker) is dropped before the
        text is passed to ``float``.
        """
        cleaned = text.replace('std_x', '').replace('_', '.')
        if cleaned.endswith('.'):
            cleaned = cleaned[:-1]
        try:
            return float(cleaned)
        except ValueError:
            return None

    def can_handle(self, method_name: str) -> bool:
        """Check if this is a basic LRP method.

        Accepts exact matches from ``supported_methods`` and, for names
        starting with ``lrp_``: epsilon rules with a decodable value,
        ``alpha`` rules that also mention ``beta``, and the rules ``z``,
        ``flat``, ``w``, ``gamma`` and ``zplus``.
        """
        method_lower = method_name.lower()

        if method_lower in self.supported_methods:
            return True

        if not method_lower.startswith('lrp_'):
            return False

        # 'lrp_' prefix guarantees at least two tokens.
        parts = method_lower.split('_')
        rule = parts[1]

        if rule == 'epsilon' and len(parts) >= 3:
            # An epsilon suffix must decode to a number to be accepted.
            return self._decode_number('_'.join(parts[2:])) is not None

        if rule == 'alpha' and 'beta' in method_lower:
            return True

        return rule in ('z', 'flat', 'w', 'gamma', 'zplus')

    def _parse_lrp_method(self, method_lower):
        """Parse LRP method name to extract rule type, parameters, and modifiers.

        Args:
            method_lower: Lower-cased method name in underscore notation.

        Returns:
            ``(rule_type, rule_params, std_x_modifier)`` where ``rule_type``
            is the first token after an optional ``lrp`` prefix,
            ``rule_params`` is a dict of parsed numeric parameters and
            ``std_x_modifier`` tells whether ``std_x`` occurs in the name.
            Values that cannot be decoded fall back to defaults
            (epsilon 0.01, alpha 1.0 / beta 0.0, gamma 0.25).
        """
        parts = method_lower.split('_')
        std_x_modifier = 'std_x' in method_lower

        if parts[0] == 'lrp':
            parts = parts[1:]

        # Bare 'lrp' defaults to the epsilon rule.
        if not parts:
            return 'epsilon', {'epsilon': 0.01}, std_x_modifier

        rule_type = parts[0]
        rule_params = {}
        tail = '_'.join(parts[1:])

        if rule_type == 'epsilon':
            value = None
            # 'lrp_epsilon_std_x' carries no value, only the modifier.
            if len(parts) > 1 and parts[1] != 'std':
                value = self._decode_number(tail)
            rule_params['epsilon'] = 0.01 if value is None else value

        elif rule_type == 'alpha':
            rule_params['alpha'] = 1.0
            rule_params['beta'] = 0.0
            alpha_idx = method_lower.find('alpha_')
            beta_idx = method_lower.find('beta_')
            if alpha_idx != -1 and beta_idx != -1:
                # Alpha sits between 'alpha_' and the '_' preceding 'beta_'.
                alpha_str = method_lower[alpha_idx + len('alpha_'):beta_idx - 1].replace('_', '.')
                beta_str = (method_lower[beta_idx + len('beta_'):]
                            .replace('_std_x', '').replace('_', '.'))
                try:
                    alpha, beta = float(alpha_str), float(beta_str)
                except ValueError:
                    pass  # keep both defaults if either value is malformed
                else:
                    rule_params['alpha'] = alpha
                    rule_params['beta'] = beta

        elif rule_type == 'gamma':
            value = self._decode_number(tail) if len(parts) > 1 else None
            rule_params['gamma'] = 0.25 if value is None else value

        elif rule_type in ('flat', 'w_square', 'w', 'z', 'zplus'):
            # These rules might carry an epsilon parameter.
            epsilon_idx = method_lower.find('epsilon_')
            if epsilon_idx != -1:
                value = self._decode_number(method_lower[epsilon_idx + len('epsilon_'):])
                rule_params['epsilon'] = 0.01 if value is None else value

        return rule_type, rule_params, std_x_modifier

    def _get_innvestigate_method_and_params(self, rule_type, rule_params, original_kwargs):
        """Map parsed rule to iNNvestigate method and parameters.

        Args:
            rule_type: Rule token from ``_parse_lrp_method``.
            rule_params: Parsed numeric parameters for that rule.
            original_kwargs: Caller kwargs; copied, never mutated.

        Returns:
            ``(method_name, method_kwargs)``. ``method_kwargs`` is the copy
            of ``original_kwargs`` without ``target_class`` and
            ``neuron_selection``, plus the rule parameters. Unknown rules
            fall back to ``lrp.epsilon`` with epsilon 0.01.
        """
        method_kwargs = original_kwargs.copy()
        # The caller passes the neuron selection separately.
        method_kwargs.pop('target_class', None)
        method_kwargs.pop('neuron_selection', None)

        if rule_type == 'epsilon':
            method_kwargs['epsilon'] = rule_params.get('epsilon', 0.01)
            return 'lrp.epsilon', method_kwargs

        if rule_type == 'alpha':
            method_kwargs['alpha'] = rule_params.get('alpha', 1.0)
            method_kwargs['beta'] = rule_params.get('beta', 0.0)
            return 'lrp.alpha_beta', method_kwargs

        if rule_type == 'gamma':
            method_kwargs['gamma'] = rule_params.get('gamma', 0.25)
            return 'lrp.gamma', method_kwargs

        if rule_type == 'flat':
            return 'lrp.flat', method_kwargs

        if rule_type in ('w_square', 'w'):
            return 'lrp.w_square', method_kwargs

        if rule_type in ('z', 'zplus'):
            return 'lrp.z', method_kwargs

        # Default fallback
        method_kwargs['epsilon'] = 0.01
        return 'lrp.epsilon', method_kwargs

    def execute_tensorflow(self, model, x, method_name: str, **kwargs):
        """Execute LRP methods for TensorFlow with comprehensive rule parsing.

        Names containing a dot (e.g. ``lrp.z``) are passed straight through
        to ``calculate_explanation_innvestigate``; all other names are
        parsed with ``_parse_lrp_method``.

        Args:
            model: Model handed to ``calculate_explanation_innvestigate``.
            x: Input; also used for the ``std_x`` division.
            method_name: Method name in dot or underscore notation.
            **kwargs: ``target_class``/``neuron_selection`` select the
                neuron; the rest is forwarded.

        Returns:
            The attribution, divided by ``np.std(x)`` when ``std_x`` is part
            of an underscore-notation name and the deviation is positive.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..tf_signxai.tf_utils import calculate_explanation_innvestigate
            import numpy as np

            method_lower = method_name.lower()
            neuron_selection = kwargs.get('target_class', kwargs.get('neuron_selection'))

            # Dot notation (e.g. lrp.z, lrp.epsilon): direct pass-through.
            if '.' in method_lower:
                return calculate_explanation_innvestigate(
                    model, x,
                    method=method_lower,
                    neuron_selection=neuron_selection,
                    **{k: v for k, v in kwargs.items()
                       if k not in ('target_class', 'neuron_selection')}
                )

            rule_type, rule_params, std_x_modifier = self._parse_lrp_method(method_lower)
            method_for_innvestigate, method_kwargs = self._get_innvestigate_method_and_params(
                rule_type, rule_params, kwargs
            )

            result = calculate_explanation_innvestigate(
                model, x,
                method=method_for_innvestigate,
                neuron_selection=neuron_selection,
                **method_kwargs
            )

            if std_x_modifier:
                std = np.std(x)
                if std > 0:  # avoid dividing by zero for constant inputs
                    result = result / std

            return result

        except Exception as e:
            logger.warning(f"LRPBasicFamily failed for {method_name}: {e}")
            raise

    def execute_pytorch(self, model, x, method_name: str, **kwargs):
        """Execute LRP methods for PyTorch with comprehensive rule parsing.

        Dot notation is converted to underscore notation first; a bare
        ``lrp.alpha_beta`` becomes ``lrp_alpha_1_beta_0`` and, for other
        dot-notation names, every ``_ib`` occurrence is removed.

        Args:
            model: Model handed to the analyzer.
            x: Input as a ``torch.Tensor`` or NumPy array.
            method_name: Method name in dot or underscore notation.
            **kwargs: ``target_class`` is passed to the analyzer.

        Returns:
            The attribution as a NumPy array when the analyzer returns a
            tensor (otherwise the analyzer's result), divided by the input's
            standard deviation when ``std_x`` is part of the name.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            import torch
            import numpy as np
            from ..torch_signxai.methods_impl.zennit_impl.analyzers import LRPAnalyzer

            method_lower = method_name.lower()

            if '.' in method_lower:
                # Map dot notation to underscore (e.g., lrp.z -> lrp_z).
                method_lower = method_lower.replace('.', '_')
                if method_lower == 'lrp_alpha_beta':
                    method_lower = 'lrp_alpha_1_beta_0'  # Default alpha-beta
                elif '_ib' in method_lower:
                    method_lower = method_lower.replace('_ib', '')  # Remove IB suffix

            # Same parsing logic as the TensorFlow path.
            rule_type, rule_params, std_x_modifier = self._parse_lrp_method(method_lower)

            def advanced_or_epsilon(rule, **rule_kwargs):
                # Prefer AdvancedLRPAnalyzer; if it cannot be imported, fall
                # back to the plain epsilon rule.
                try:
                    from ..torch_signxai.methods_impl.zennit_impl.analyzers import AdvancedLRPAnalyzer
                    return AdvancedLRPAnalyzer(model, rule, **rule_kwargs)
                except ImportError:
                    return LRPAnalyzer(model, 'epsilon', 0.01)

            if rule_type == 'epsilon':
                analyzer = LRPAnalyzer(model, 'epsilon', rule_params.get('epsilon', 0.01))
            elif rule_type == 'alpha':
                analyzer = LRPAnalyzer(
                    model, 'alphabeta',
                    alpha=rule_params.get('alpha', 1.0),
                    beta=rule_params.get('beta', 0.0)
                )
            elif rule_type == 'gamma':
                analyzer = advanced_or_epsilon('gamma', gamma=rule_params.get('gamma', 0.25))
            elif rule_type == 'flat':
                analyzer = advanced_or_epsilon('flat')
            elif rule_type in ('w_square', 'w'):
                analyzer = advanced_or_epsilon('wsquare')
            elif rule_type in ('z', 'zplus'):
                analyzer = LRPAnalyzer(model, 'zplus')
            else:
                # Default fallback
                analyzer = LRPAnalyzer(model, 'epsilon', 0.01)

            result = analyzer.analyze(x, kwargs.get('target_class'))

            # Ensure numpy array
            if isinstance(result, torch.Tensor):
                result = result.detach().cpu().numpy()

            if std_x_modifier:
                x_np = x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x
                std = np.std(x_np)
                if std > 0:  # avoid dividing by zero for constant inputs
                    result = result / std

            return result

        except Exception as e:
            logger.warning(f"LRPBasicFamily failed for {method_name}: {e}")
            raise
[docs] class SpecializedLRPFamily(MethodFamily): """ Handles complex LRP methods that need special handling. Higher risk - these methods have complex requirements. """
def __init__(self):
    """Register the method names this family accepts by exact match.

    Parameterised variants (e.g. an epsilon value appended to the name) are
    not listed here; ``can_handle`` recognises them by prefix.
    """
    super().__init__()
    self.supported_methods = {
        # Basic specialized methods
        'lrp_flat',
        'lrp_w_square',
        'lrp_gamma',
        'lrp_gamma_0_25',
        'lrp_sequential_composite_a',
        'lrp_sequential_composite_b',
        'deep_taylor',
        'deep_taylor_bounded',
        'pattern_attribution',
        'pattern_net',
        # LRP Sign variations (base names plus common presets)
        'lrpsign_epsilon',
        'lrpsign_alpha_beta',
        'lrpsign_alpha_1_beta_0',
        'lrpsign_alpha_2_beta_1',
        # LRP Z variations
        'lrpz_epsilon',
        'lrpz_sequential_composite_a',
        'lrpz_sequential_composite_b',
        # Flat LRP variations
        'flatlrp_epsilon',
        'flatlrp_alpha_beta',
        # W-square LRP variations
        'w2lrp_epsilon',
        'w2lrp_alpha_beta',
        # Z-Box LRP variations
        'zblrp_epsilon',
        'zblrp_sequential_composite_a',
        'zblrp_sequential_composite_b',
    }
def can_handle(self, method_name: str) -> bool:
    """Check if this is a specialized LRP method.

    Matching is case-insensitive. A name is accepted when it

    * is listed in ``supported_methods``,
    * starts with one of the specialized prefixes (``lrpsign_``, ``lrpz_``,
      ``flatlrp_``, ``w2lrp_``, ``zblrp_``),
    * contains one of those prefixes (without the underscore) together with
      ``vgg16ilsvrc``,
    * contains ``lrp`` together with one of ``flat``, ``w_square``,
      ``gamma``, ``sequential`` or ``composite``, or
    * starts with ``deep_taylor``, ``pattern_attribution`` or
      ``pattern_net``.
    """
    name = method_name.lower()

    if name in self.supported_methods:
        return True

    variant_tags = ('lrpsign', 'lrpz', 'flatlrp', 'w2lrp', 'zblrp')

    if name.startswith(tuple(tag + '_' for tag in variant_tags)):
        return True

    # VGG16ILSVRC variants (e.g., zblrp_epsilon_0_1_VGG16ILSVRC) only need
    # the tag somewhere in the name.
    if 'vgg16ilsvrc' in name and any(tag in name for tag in variant_tags):
        return True

    # Complex LRP variants that need special handling.
    complex_markers = ('flat', 'w_square', 'gamma', 'sequential', 'composite')
    if 'lrp' in name and any(marker in name for marker in complex_markers):
        return True

    # Deep Taylor and Pattern methods.
    return name.startswith(('deep_taylor', 'pattern_attribution', 'pattern_net'))
def execute_tensorflow(self, model, x, method_name: str, **kwargs):
    """Execute specialized LRP methods for TensorFlow with sign variants.

    Only names starting with ``lrpsign_``, ``lrpz_``, ``flatlrp_``,
    ``w2lrp_`` or ``zblrp_`` are executed here (via
    ``_execute_sign_variant_tensorflow``).

    Args:
        model: Forwarded to the sign-variant helper.
        x: Forwarded to the sign-variant helper.
        method_name: Method name; lower-cased before dispatch.
        **kwargs: Forwarded to the sign-variant helper.

    Returns:
        The helper's result.

    Raises:
        NotImplementedError: For every other name; the message tells the
            caller to use the original wrappers.
        Exception: Any failure is logged as a warning and re-raised.
    """
    try:
        # Imported up front so a missing TensorFlow backend surfaces here,
        # before any dispatch decision is made.
        from ..tf_signxai.tf_utils import calculate_explanation_innvestigate
        import numpy as np

        name = method_name.lower()
        sign_variant_prefixes = ('lrpsign_', 'lrpz_', 'flatlrp_', 'w2lrp_', 'zblrp_')

        if not name.startswith(sign_variant_prefixes):
            # Other complex methods - fallback to original wrappers.
            raise NotImplementedError("Use original wrappers for complex LRP methods")

        return self._execute_sign_variant_tensorflow(model, x, name, **kwargs)

    except Exception as e:
        logger.warning(f"SpecializedLRPFamily TensorFlow failed for {method_name}: {e}")
        raise
def _execute_sign_variant_tensorflow(self, model, x, method_lower, **kwargs): """Execute LRP sign variants for TensorFlow.""" from ..tf_signxai.tf_utils import calculate_explanation_innvestigate import numpy as np # Parse the method components rule_type, rule_params, modifiers = self._parse_specialized_lrp_method(method_lower) # Get base LRP result if rule_type == 'lrpsign': # Use base epsilon or alpha_beta rule if 'epsilon' in rule_params: method_for_innvestigate = 'lrp.epsilon' method_kwargs = {'epsilon': rule_params['epsilon']} elif 'alpha' in rule_params and 'beta' in rule_params: method_for_innvestigate = 'lrp.alpha_beta' method_kwargs = {'alpha': rule_params['alpha'], 'beta': rule_params['beta']} else: method_for_innvestigate = 'lrp.epsilon' method_kwargs = {'epsilon': 0.25} elif rule_type == 'lrpz': if 'sequential' in modifiers: composite_type = 'composite_a' if 'composite_a' in modifiers else 'composite_b' method_for_innvestigate = f'lrp.sequential_{composite_type}' method_kwargs = {} elif 'alpha' in rule_params: method_for_innvestigate = 'lrp.alpha_beta' method_kwargs = {'alpha': rule_params['alpha'], 'beta': rule_params['beta']} else: method_for_innvestigate = 'lrp.z' method_kwargs = {} elif rule_type == 'flatlrp': method_for_innvestigate = 'lrp.flat' method_kwargs = {} elif rule_type == 'w2lrp': method_for_innvestigate = 'lrp.w_square' method_kwargs = {} elif rule_type == 'zblrp': # Z-box is VGG16-specific, use regular LRP as fallback if 'sequential' in modifiers: composite_type = 'composite_a' if 'composite_a' in modifiers else 'composite_b' method_for_innvestigate = f'lrp.sequential_{composite_type}' method_kwargs = {} elif 'alpha' in rule_params: method_for_innvestigate = 'lrp.alpha_beta' method_kwargs = {'alpha': rule_params['alpha'], 'beta': rule_params['beta']} else: method_for_innvestigate = 'lrp.epsilon' method_kwargs = {'epsilon': rule_params.get('epsilon', 0.01)} else: # Fallback method_for_innvestigate = 'lrp.epsilon' method_kwargs = 
{'epsilon': 0.01} # Execute base LRP method result = calculate_explanation_innvestigate( model, x, method=method_for_innvestigate, neuron_selection=kwargs.get('target_class', kwargs.get('neuron_selection')), **method_kwargs ) # Apply modifiers if rule_type == 'lrpsign': # Apply sign modifier if 'mu' in modifiers: mu = modifiers['mu'] from ..tf_signxai.methods_impl.signed import calculate_sign_mu result = result * calculate_sign_mu(x, mu) else: result = result * np.sign(x) if 'std_x' in modifiers: std = np.std(x) if std > 0: result = result / std return result def _parse_specialized_lrp_method(self, method_lower): """Parse specialized LRP method names to extract components.""" modifiers = {} # Check for std_x modifier if 'std_x' in method_lower: modifiers['std_x'] = True # Check for mu modifiers if '_mu_' in method_lower: if 'mu_0_5' in method_lower: modifiers['mu'] = 0.5 elif 'mu_neg_0_5' in method_lower: modifiers['mu'] = -0.5 else: modifiers['mu'] = 0.0 # Check for sequential composite if 'sequential' in method_lower: modifiers['sequential'] = True if 'composite_a' in method_lower: modifiers['composite_a'] = True elif 'composite_b' in method_lower: modifiers['composite_b'] = True # Determine rule type if method_lower.startswith('lrpsign_'): rule_type = 'lrpsign' elif method_lower.startswith('lrpz_'): rule_type = 'lrpz' elif method_lower.startswith('flatlrp_'): rule_type = 'flatlrp' elif method_lower.startswith('w2lrp_'): rule_type = 'w2lrp' elif method_lower.startswith('zblrp_'): rule_type = 'zblrp' else: rule_type = 'unknown' # Parse parameters rule_params = {} # Parse epsilon values if 'epsilon_' in method_lower: epsilon_idx = method_lower.find('epsilon_') epsilon_start = epsilon_idx + 8 # Length of 'epsilon_' # Find end of epsilon value remaining = method_lower[epsilon_start:] epsilon_str = '' for part in remaining.split('_'): if part and part not in ['std', 'x', 'mu', 'neg', 'sequential', 'composite', 'a', 'b']: try: float(part.replace('_', '.')) epsilon_str 
+= part + '_' except ValueError: break else: break if epsilon_str: epsilon_str = epsilon_str.rstrip('_').replace('_', '.') try: rule_params['epsilon'] = float(epsilon_str) except ValueError: rule_params['epsilon'] = 0.25 # Parse alpha/beta values if 'alpha_' in method_lower and 'beta_' in method_lower: alpha_idx = method_lower.find('alpha_') beta_idx = method_lower.find('beta_') if alpha_idx != -1 and beta_idx != -1: # Extract alpha alpha_start = alpha_idx + 6 # Length of 'alpha_' alpha_part = method_lower[alpha_start:beta_idx-1] try: rule_params['alpha'] = float(alpha_part.replace('_', '.')) except ValueError: rule_params['alpha'] = 1.0 # Extract beta beta_start = beta_idx + 5 # Length of 'beta_' beta_part = method_lower[beta_start:].split('_')[0] try: rule_params['beta'] = float(beta_part.replace('_', '.')) except ValueError: rule_params['beta'] = 0.0 return rule_type, rule_params, modifiers
def execute_pytorch(self, model, x, method_name: str, **kwargs):
    """Execute specialized LRP methods for PyTorch with comprehensive parsing.

    Names starting with ``lrpsign_``, ``lrpz_``, ``flatlrp_``, ``w2lrp_`` or
    ``zblrp_`` are delegated to ``_execute_sign_variant_pytorch``. All other
    names are mapped to an ``AdvancedLRPAnalyzer`` by substring.

    Args:
        model: Model handed to the analyzer.
        x: Input as a ``torch.Tensor`` or NumPy array.
        method_name: Method name. A gamma value may follow ``gamma_`` with
            underscores standing for the decimal point
            (``lrp_gamma_0_25`` -> 0.25; default 0.25).
        **kwargs: Forwarded to ``AdvancedLRPAnalyzer``; ``target_class`` is
            also passed to ``analyze``.

    Returns:
        The attribution as a NumPy array when the analyzer returns a tensor
        (otherwise the analyzer's result), multiplied by the input for
        ``x_input`` names and by its sign for ``x_sign`` names.

    Raises:
        Exception: Any failure is logged as a warning and re-raised.
    """
    try:
        import torch
        import numpy as np

        method_lower = method_name.lower()

        if method_lower.startswith(('lrpsign_', 'lrpz_', 'flatlrp_', 'w2lrp_', 'zblrp_')):
            return self._execute_sign_variant_pytorch(model, x, method_lower, **kwargs)

        # Other complex methods - use original implementation.
        from ..torch_signxai.methods_impl.zennit_impl.analyzers import AdvancedLRPAnalyzer

        if 'flat' in method_lower:
            analyzer = AdvancedLRPAnalyzer(model, 'flat', **kwargs)
        elif 'w_square' in method_lower or 'wsquare' in method_lower or 'w2' in method_lower:
            analyzer = AdvancedLRPAnalyzer(model, 'wsquare', **kwargs)
        elif 'gamma' in method_lower:
            gamma = 0.25  # Default
            if 'gamma_' in method_lower:
                # Gather every leading numeric token so that '0_25' decodes
                # to 0.25 rather than stopping at the integer part.
                number_tokens = []
                for token in method_lower.split('gamma_')[1].split('_'):
                    try:
                        float(token)
                    except ValueError:
                        break
                    number_tokens.append(token)
                if number_tokens:
                    try:
                        gamma = float('.'.join(number_tokens))
                    except ValueError:
                        pass  # malformed value: keep the default
            analyzer = AdvancedLRPAnalyzer(model, 'gamma', gamma=gamma, **kwargs)
        else:
            # Fallback to advanced analyzer
            analyzer = AdvancedLRPAnalyzer(model, method_lower, **kwargs)

        result = analyzer.analyze(x, kwargs.get('target_class'))

        # Convert first so the NumPy-based modifiers below never mix a
        # torch.Tensor with a NumPy array.
        if isinstance(result, torch.Tensor):
            result = result.detach().cpu().numpy()

        wants_input = 'x_input' in method_lower
        wants_sign = 'x_sign' in method_lower
        if wants_input or wants_sign:
            x_np = x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x
            if wants_input:
                result = result * x_np
            if wants_sign:
                result = result * np.sign(x_np)

        return result

    except Exception as e:
        logger.warning(f"SpecializedLRPFamily failed for {method_name}: {e}")
        raise
def _execute_sign_variant_pytorch(self, model, x, method_lower, **kwargs):
    """Run one of the prefixed LRP variants (lrpsign/lrpz/flatlrp/w2lrp/zblrp) on PyTorch.

    The method name is decomposed by ``_parse_specialized_lrp_method`` into a
    rule type, rule parameters and modifiers. The rule type selects the
    analyzer; the result is converted to NumPy and then post-processed:
    ``lrpsign`` results are multiplied by a sign map of the input (a ``mu``
    modifier switches to ``calculate_sign_mu``), and a ``std_x`` modifier
    divides by the input's standard deviation when that is positive.

    Args:
        model: Model handed to the analyzer.
        x: Input to explain (tensor or array).
        method_lower: Lower-cased method name.
        **kwargs: Only ``target_class`` is read; it is passed to ``analyze``.

    Returns:
        The post-processed relevance map.
    """
    import torch
    import numpy as np
    from ..torch_signxai.methods_impl.zennit_impl.analyzers import LRPAnalyzer

    rule_type, rule_params, modifiers = self._parse_specialized_lrp_method(method_lower)

    def advanced_or(kind, make_fallback):
        # Prefer AdvancedLRPAnalyzer; an ImportError selects the plain fallback.
        try:
            from ..torch_signxai.methods_impl.zennit_impl.analyzers import AdvancedLRPAnalyzer
            return AdvancedLRPAnalyzer(model, kind)
        except ImportError:
            return make_fallback()

    def composite_kind():
        return 'composite_a' if 'composite_a' in modifiers else 'composite_b'

    def alpha_beta_from_params():
        return LRPAnalyzer(
            model,
            'alphabeta',
            alpha=rule_params['alpha'],
            beta=rule_params['beta']
        )

    def input_as_numpy():
        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x

    # --- choose the analyzer for the parsed rule type ---
    if rule_type == 'lrpsign':
        if 'epsilon' in rule_params:
            analyzer = LRPAnalyzer(model, 'epsilon', rule_params['epsilon'])
        elif 'alpha' in rule_params and 'beta' in rule_params:
            analyzer = alpha_beta_from_params()
        else:
            analyzer = LRPAnalyzer(model, 'epsilon', 0.25)
    elif rule_type == 'lrpz':
        if 'sequential' in modifiers:
            analyzer = advanced_or(composite_kind(), lambda: LRPAnalyzer(model, 'zplus'))
        elif 'alpha' in rule_params:
            analyzer = alpha_beta_from_params()
        else:
            analyzer = LRPAnalyzer(model, 'zplus')
    elif rule_type == 'flatlrp':
        analyzer = advanced_or('flat', lambda: LRPAnalyzer(model, 'epsilon', 0.01))
    elif rule_type == 'w2lrp':
        analyzer = advanced_or('wsquare', lambda: LRPAnalyzer(model, 'epsilon', 0.01))
    elif rule_type == 'zblrp':
        # Z-box is VGG16-specific, so regular LRP rules stand in for it here.
        if 'sequential' in modifiers:
            analyzer = advanced_or(
                composite_kind(),
                lambda: LRPAnalyzer(model, 'alphabeta', alpha=1.0, beta=0.0)
            )
        elif 'alpha' in rule_params:
            analyzer = alpha_beta_from_params()
        else:
            analyzer = LRPAnalyzer(model, 'epsilon', rule_params.get('epsilon', 0.01))
    else:
        analyzer = LRPAnalyzer(model, 'epsilon', 0.01)

    # --- run it and normalise the output to NumPy ---
    relevance = analyzer.analyze(x, kwargs.get('target_class'))
    if isinstance(relevance, torch.Tensor):
        relevance = relevance.detach().cpu().numpy()

    # --- post-processing driven by rule type and modifiers ---
    if rule_type == 'lrpsign':
        x_np = input_as_numpy()
        if 'mu' in modifiers:
            mu = modifiers['mu']
            from ..torch_signxai.methods_impl.signed import calculate_sign_mu
            relevance = relevance * calculate_sign_mu(x_np, mu)
        else:
            relevance = relevance * np.sign(x_np)

    if 'std_x' in modifiers:
        spread = np.std(input_as_numpy())
        if spread > 0:
            relevance = relevance / spread

    return relevance
class DeepLiftFamily(MethodFamily):
    """
    Handles DeepLift and related methods.
    """

    def __init__(self):
        """Register the DeepLift method names this family advertises."""
        super().__init__()
        self.supported_methods = {'deeplift', 'deep_lift', 'deeplift_rescale'}

    def can_handle(self, method_name: str) -> bool:
        """Return True when the name contains ``deeplift`` or ``deep_lift`` (case-insensitive)."""
        method_lower = method_name.lower()
        return 'deeplift' in method_lower or 'deep_lift' in method_lower

    def execute_tensorflow(self, model, x, method_name: str, **kwargs):
        """Execute DeepLift for TensorFlow.

        Args:
            model: Model to explain.
            x: Input to explain.
            method_name: Requested method name (used for logging only).
            **kwargs: ``target_class`` (preferred) or ``neuron_selection``
                selects the output neuron; every other entry is forwarded.

        Returns:
            Whatever ``calculate_explanation_innvestigate`` returns.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..tf_signxai.tf_utils import calculate_explanation_innvestigate

            neuron_selection = kwargs.get('target_class', kwargs.get('neuron_selection'))
            # The selection is passed explicitly, so strip both spellings from
            # the forwarded kwargs; otherwise 'neuron_selection' would be
            # supplied twice and the call would raise TypeError.
            forwarded = {
                key: value for key, value in kwargs.items()
                if key not in ('target_class', 'neuron_selection')
            }

            return calculate_explanation_innvestigate(
                model, x,
                method='deeplift.wrapper',
                neuron_selection=neuron_selection,
                **forwarded
            )
        except Exception as e:
            logger.warning(f"DeepLiftFamily failed for {method_name}: {e}")
            raise

    def execute_pytorch(self, model, x, method_name: str, **kwargs):
        """Execute DeepLift for PyTorch.

        Args:
            model: Model to explain.
            x: Input to explain.
            method_name: Requested method name (used for logging only).
            **kwargs: Only ``target_class`` is read.

        Returns:
            The analyzer result; a tensor result is converted to NumPy.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..torch_signxai.methods_impl.zennit_impl.analyzers import DeepLiftAnalyzer
            import torch

            analyzer = DeepLiftAnalyzer(model)
            result = analyzer.analyze(x, kwargs.get('target_class'))

            if isinstance(result, torch.Tensor):
                result = result.detach().cpu().numpy()

            return result
        except Exception as e:
            logger.warning(f"DeepLiftFamily failed for {method_name}: {e}")
            raise
class GuidedFamily(MethodFamily):
    """
    Handles guided backprop, deconvnet and related methods.
    """

    def __init__(self):
        """Register the guided/deconvnet method names this family advertises."""
        super().__init__()
        self.supported_methods = {
            'guided_backprop', 'deconvnet', 'guided_grad_cam',
            'guided_backprop_x_input', 'guided_backprop_x_sign',
            'deconvnet_x_input', 'deconvnet_x_sign'
        }

    def can_handle(self, method_name: str) -> bool:
        """Return True for names starting with ``guided`` or ``deconvnet``.

        Any name containing ``grad_cam`` (which includes ``guided_grad_cam``)
        is rejected so that CAMFamily handles it instead.
        """
        method_lower = method_name.lower()
        if 'grad_cam' in method_lower:
            return False
        return method_lower.startswith(('guided', 'deconvnet'))

    def execute_tensorflow(self, model, x, method_name: str, **kwargs):
        """Execute guided methods for TensorFlow with dynamic modifiers.

        Args:
            model: Model to explain.
            x: Input to explain.
            method_name: Requested method name; ``_x_input``, ``_x_sign`` and
                ``std_x`` substrings enable the matching modifier.
            **kwargs: ``target_class``/``neuron_selection`` select the output
                neuron, ``modifier`` can request modifiers by substring
                (``input``, ``sign``, ``std``), ``mu`` switches the sign
                modifier to ``calculate_sign_mu``. Everything except
                ``target_class``, ``neuron_selection`` and ``modifier`` is
                forwarded to ``calculate_explanation_innvestigate``.

        Returns:
            The (optionally modified) relevance map.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..tf_signxai.tf_utils import calculate_explanation_innvestigate
            import numpy as np

            method_lower = method_name.lower()

            # Only an explicit deconvnet name selects deconvnet; everything
            # else (including names containing 'guided') runs guided backprop.
            if 'guided' not in method_lower and 'deconv' in method_lower:
                base_method = 'deconvnet'
            else:
                base_method = 'guided_backprop'

            reserved = ('target_class', 'neuron_selection', 'modifier')
            result = calculate_explanation_innvestigate(
                model, x,
                method=base_method,
                neuron_selection=kwargs.get('target_class', kwargs.get('neuron_selection')),
                **{k: v for k, v in kwargs.items() if k not in reserved}
            )

            # Apply modifiers dynamically
            modifiers = kwargs.get('modifier', '')

            if '_x_input' in method_lower or 'input' in modifiers:
                result = result * x

            if '_x_sign' in method_lower or 'sign' in modifiers:
                # Use mu from kwargs if available (parsed by MethodParser)
                if 'mu' in kwargs:
                    mu = kwargs['mu']
                    from ..tf_signxai.methods_impl.signed import calculate_sign_mu
                    result = result * calculate_sign_mu(x, mu)
                else:
                    result = result * np.sign(x)

            if 'std_x' in method_lower or 'std' in modifiers:
                std = np.std(x)
                if std > 0:
                    result = result / std

            return result
        except Exception as e:
            logger.warning(f"GuidedFamily failed for {method_name}: {e}")
            raise

    def execute_pytorch(self, model, x, method_name: str, **kwargs):
        """Execute guided methods for PyTorch.

        Args:
            model: Model to explain.
            x: Input to explain (tensor or array).
            method_name: Must contain ``guided_backprop`` or ``deconvnet``;
                ``x_input`` / ``x_sign`` substrings enable the modifiers.
            **kwargs: ``target_class`` is passed to the analyzer; ``mu``
                switches the sign modifier to ``calculate_sign_mu``.

        Returns:
            The (optionally modified) relevance map; a tensor result is
            converted to a NumPy array.

        Raises:
            ValueError: If the name matches neither analyzer (logged, re-raised).
            Exception: Any other failure is logged as a warning and re-raised.
        """
        try:
            import torch
            import numpy as np

            method_lower = method_name.lower()

            if 'guided_backprop' in method_lower:
                from ..torch_signxai.methods_impl.zennit_impl.analyzers import GuidedBackpropAnalyzer
                analyzer = GuidedBackpropAnalyzer(model)
            elif 'deconvnet' in method_lower:
                from ..torch_signxai.methods_impl.zennit_impl.analyzers import DeconvNetAnalyzer
                analyzer = DeconvNetAnalyzer(model)
            else:
                raise ValueError(f"Unknown guided method: {method_name}")

            result = analyzer.analyze(x, kwargs.get('target_class'))

            # Convert before applying modifiers: the modifiers multiply by a
            # NumPy view of the input, so the result must not still be a tensor.
            if isinstance(result, torch.Tensor):
                result = result.detach().cpu().numpy()

            wants_input = 'x_input' in method_lower
            wants_sign = 'x_sign' in method_lower
            if wants_input or wants_sign:
                x_np = x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x

                if wants_input:
                    result = result * x_np

                if wants_sign:
                    # Use mu from kwargs if available (parsed by MethodParser)
                    if 'mu' in kwargs:
                        mu = kwargs['mu']
                        from ..torch_signxai.methods_impl.signed import calculate_sign_mu
                        result = result * calculate_sign_mu(x_np, mu)
                    else:
                        result = result * np.sign(x_np)

            return result
        except Exception as e:
            logger.warning(f"GuidedFamily failed for {method_name}: {e}")
            raise
class CAMFamily(MethodFamily):
    """
    Handles GradCAM and other CAM-based methods.
    """

    def __init__(self):
        """Register the CAM method names this family advertises."""
        super().__init__()
        self.supported_methods = {
            'grad_cam', 'gradcam', 'grad_cam_timeseries',
            'scorecam', 'layercam', 'xgradcam',
            'grad_cam_VGG16ILSVRC'
        }

    def can_handle(self, method_name: str) -> bool:
        """Return True when the lower-cased name contains ``cam``."""
        # 'grad_cam' always contains 'cam', so one substring test suffices.
        return 'cam' in method_name.lower()

    def execute_tensorflow(self, model, x, method_name: str, **kwargs):
        """Execute CAM methods for TensorFlow.

        Args:
            model: Model to explain; ``model.layers`` is scanned for a layer
                whose name contains ``conv`` when no layer name is given.
            x: Input to explain; inputs with ``ndim <= 3`` use the timeseries
                implementation.
            method_name: Requested method name. ``VGG16ILSVRC`` forces the
                layer ``block5_conv3``; ``timeseries`` forces the timeseries
                implementation.
            **kwargs: ``last_conv_layer_name`` names the layer,
                ``target_class``/``neuron_selection`` select the output
                neuron; the rest is forwarded.

        Returns:
            Whatever the selected Grad-CAM implementation returns.

        Raises:
            ValueError: If no convolutional layer can be determined.
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..tf_signxai.methods_impl.grad_cam import (
                calculate_grad_cam_relevancemap,
                calculate_grad_cam_relevancemap_timeseries
            )

            call_kwargs = kwargs.copy()
            last_conv_layer_name = call_kwargs.pop('last_conv_layer_name', None)
            # These are passed explicitly as neuron_selection below; leaving
            # them in call_kwargs would supply 'neuron_selection' twice.
            fallback_selection = call_kwargs.pop('neuron_selection', None)
            neuron_selection = call_kwargs.pop('target_class', fallback_selection)

            # Handle VGG16-specific methods
            if 'VGG16ILSVRC' in method_name:
                last_conv_layer_name = 'block5_conv3'  # Standard VGG16 last conv layer

            # If not provided, try to find the last conv layer automatically
            if last_conv_layer_name is None:
                for layer in reversed(model.layers):
                    if 'conv' in layer.name.lower():
                        last_conv_layer_name = layer.name
                        break

                if last_conv_layer_name is None:
                    raise ValueError("No convolutional layer found and last_conv_layer_name not specified")

            # Determine if timeseries or image based on input shape
            if x.ndim <= 3 or 'timeseries' in method_name.lower():
                grad_cam_fn = calculate_grad_cam_relevancemap_timeseries
            else:
                grad_cam_fn = calculate_grad_cam_relevancemap

            return grad_cam_fn(
                x, model,
                last_conv_layer_name=last_conv_layer_name,
                neuron_selection=neuron_selection,
                **call_kwargs
            )
        except Exception as e:
            logger.warning(f"CAMFamily failed for {method_name}: {e}")
            raise

    def execute_pytorch(self, model, x, method_name: str, **kwargs):
        """Execute CAM methods for PyTorch.

        Names containing ``guided`` produce Guided Grad-CAM: the Grad-CAM map
        multiplied element-wise with the guided-backprop attribution.

        Args:
            model: Model to explain.
            x: Input to explain; inputs with ``ndim <= 3`` use the timeseries
                implementation.
            method_name: Requested method name (``VGG16ILSVRC``,
                ``timeseries`` and ``guided`` are recognised).
            **kwargs: ``target_class`` selects the output; the rest is
                forwarded to the Grad-CAM implementation.

        Returns:
            The relevance map; tensor results are converted to NumPy.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..torch_signxai.methods_impl.grad_cam import (
                calculate_grad_cam_relevancemap,
                calculate_grad_cam_relevancemap_timeseries
            )
            import torch
            import numpy as np

            # Remove target_class from kwargs to avoid passing it twice
            call_kwargs = kwargs.copy()
            target_class = call_kwargs.pop('target_class', None)
            method_lower = method_name.lower()

            # Handle VGG16-specific methods
            if 'VGG16ILSVRC' in method_name:
                # For VGG models, features[28] is the last conv layer (Conv2d);
                # otherwise the implementation is left to auto-detect.
                if hasattr(model, 'features') and len(model.features) > 28:
                    call_kwargs['target_layer'] = model.features[28]

            # .ndim is available on both torch tensors and NumPy arrays,
            # whereas .dim() exists on tensors only.
            if x.ndim <= 3 or 'timeseries' in method_lower:
                grad_cam_fn = calculate_grad_cam_relevancemap_timeseries
            else:
                grad_cam_fn = calculate_grad_cam_relevancemap

            result = grad_cam_fn(model, x, target_class=target_class, **call_kwargs)
            if isinstance(result, torch.Tensor):
                result = result.detach().cpu().numpy()

            if 'guided' not in method_lower:
                return result

            # Guided Grad-CAM: combine Grad-CAM with guided backpropagation
            from ..torch_signxai.methods_impl.guided import GuidedBackprop
            guided_result = GuidedBackprop(model).attribute(x, target=target_class)
            if isinstance(guided_result, torch.Tensor):
                guided_result = guided_result.detach().cpu().numpy()

            # Resize grad_cam to match guided_result shape if needed
            if result.shape != guided_result.shape:
                if len(guided_result.shape) == 3 and len(result.shape) == 2:
                    # Image case: grad_cam is (H, W) while guided is (C, H, W)
                    result = np.expand_dims(result, axis=0)
                    result = np.repeat(result, guided_result.shape[0], axis=0)
                elif len(guided_result.shape) == 4 and len(result.shape) == 3:
                    # Batch case: expand grad_cam to match batch and channels
                    result = np.expand_dims(result, axis=1)
                    result = np.repeat(result, guided_result.shape[1], axis=1)

            # Element-wise multiplication (this is the core of Guided Grad-CAM)
            return result * guided_result
        except Exception as e:
            logger.warning(f"CAMFamily failed for {method_name}: {e}")
            raise
class OcclusionFamily(MethodFamily):
    """
    Handles occlusion-based methods.
    """

    def __init__(self):
        """Register the occlusion method names this family advertises."""
        super().__init__()
        self.supported_methods = {'occlusion', 'occlusion_sensitivity'}

    def can_handle(self, method_name: str) -> bool:
        """Return True when the lower-cased name contains ``occlusion``."""
        return 'occlusion' in method_name.lower()

    def execute_tensorflow(self, model, x, method_name: str, **kwargs):
        """Execute occlusion for TensorFlow.

        Args:
            model: Model to explain.
            x: Input to explain.
            method_name: Requested method name (used for logging only).
            **kwargs: ``target_class`` (preferred) or ``neuron_selection``
                selects the output neuron; every other entry is forwarded.

        Returns:
            Whatever ``calculate_occlusion_relevancemap`` returns.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..tf_signxai.methods_impl.occlusion import calculate_occlusion_relevancemap

            call_kwargs = dict(kwargs)
            # Passed explicitly below; leaving them in would duplicate the
            # 'neuron_selection' keyword and raise TypeError.
            fallback_selection = call_kwargs.pop('neuron_selection', None)
            neuron_selection = call_kwargs.pop('target_class', fallback_selection)

            return calculate_occlusion_relevancemap(
                x, model,
                neuron_selection=neuron_selection,
                **call_kwargs
            )
        except Exception as e:
            logger.warning(f"OcclusionFamily failed for {method_name}: {e}")
            raise

    def execute_pytorch(self, model, x, method_name: str, **kwargs):
        """Execute occlusion for PyTorch.

        Args:
            model: Model to explain.
            x: Input to explain.
            method_name: Requested method name (used for logging only).
            **kwargs: ``target_class`` selects the output; every other entry
                is forwarded.

        Returns:
            The relevance map; a tensor result is converted to NumPy.

        Raises:
            Exception: Any failure is logged as a warning and re-raised.
        """
        try:
            from ..torch_signxai.methods_impl.occlusion import calculate_occlusion_relevancemap
            import torch

            call_kwargs = dict(kwargs)
            # target_class is passed explicitly; forwarding it again through
            # **kwargs raised TypeError whenever a target class was given.
            target_class = call_kwargs.pop('target_class', None)

            result = calculate_occlusion_relevancemap(
                model, x,
                target_class=target_class,
                **call_kwargs
            )

            if isinstance(result, torch.Tensor):
                result = result.detach().cpu().numpy()

            return result
        except Exception as e:
            logger.warning(f"OcclusionFamily failed for {method_name}: {e}")
            raise
class RandomFamily(MethodFamily):
    """
    Handles random baseline methods.
    """

    def __init__(self):
        """Register the random-baseline method names this family advertises."""
        super().__init__()
        self.supported_methods = {'random', 'random_uniform'}

    def can_handle(self, method_name: str) -> bool:
        """Return True when the lower-cased name contains ``random``."""
        return 'random' in method_name.lower()

    @staticmethod
    def _uniform_like(x):
        """Return uniform noise in [-1, 1) with the same shape as ``x``."""
        import numpy as np

        # tuple() normalises framework-specific shape objects for NumPy.
        return np.random.uniform(-1, 1, size=tuple(x.shape))

    def execute_tensorflow(self, model, x, method_name: str, **kwargs):
        """Execute random for TensorFlow.

        The model, method name and kwargs are ignored; only ``x.shape`` is used.
        """
        return self._uniform_like(x)

    def execute_pytorch(self, model, x, method_name: str, **kwargs):
        """Execute random for PyTorch.

        The model, method name and kwargs are ignored; only ``x.shape`` is
        used, so torch itself is not required here.
        """
        return self._uniform_like(x)
[docs] class MethodFamilyRegistry: """ Registry that manages all method families and routes requests. """
def __init__(self):
    """Build the registry: register the families, then create the method parser.

    ``fallback_handler`` starts out as ``None``. The parser is imported
    only after the families have been initialised.
    """
    self.families = []
    self.fallback_handler = None
    self._initialize_families()

    # Parser used by execute() for dynamic parameter extraction.
    from .method_parser import MethodParser as _MethodParser
    self.parser = _MethodParser()
def _initialize_families(self):
    """Populate ``self.families`` according to environment switches.

    ``SIGNXAI_USE_ALL_FAMILIES`` (default ``'true'``) enables every family.
    When it is not ``'true'``, each family is enabled by its own
    ``SIGNXAI_USE_*`` variable, which also defaults to ``'true'``.
    Families are appended in the fixed order listed below.
    """
    def switched_on(variable):
        return os.environ.get(variable, 'true').lower() == 'true'

    enable_everything = switched_on('SIGNXAI_USE_ALL_FAMILIES')

    # (environment switch, family class) in registration order.
    switches = (
        ('SIGNXAI_USE_SIMPLE_GRADIENT', SimpleGradientFamily),
        ('SIGNXAI_USE_STOCHASTIC', StochasticMethodFamily),
        ('SIGNXAI_USE_LRP_BASIC', LRPBasicFamily),
        ('SIGNXAI_USE_SPECIALIZED_LRP', SpecializedLRPFamily),
        ('SIGNXAI_USE_DEEPLIFT', DeepLiftFamily),
        ('SIGNXAI_USE_GUIDED', GuidedFamily),
        ('SIGNXAI_USE_CAM', CAMFamily),
        ('SIGNXAI_USE_OCCLUSION', OcclusionFamily),
        ('SIGNXAI_USE_RANDOM', RandomFamily),
    )

    for variable, family_cls in switches:
        if enable_everything or switched_on(variable):
            self.families.append(family_cls())
def execute(self, model, x, method_name: str, framework: str, **kwargs):
    """
    Route a method call to the first family that handles it successfully.

    The method name is parsed into a base method, parameters and modifiers.
    Parsed parameters fill in kwargs that the caller did not supply, and
    non-empty modifiers are added under ``_modifiers``. Families are tried
    in registration order; a family that raises is logged and skipped. If
    none succeeds, the original wrappers are used as fallback.
    """
    parse_result = self.parser.parse(method_name)
    base_method = parse_result['base_method']
    extracted_params = parse_result['params']
    modifiers = parse_result['modifiers']

    # Caller-supplied kwargs win over values parsed from the name.
    for name, value in extracted_params.items():
        kwargs.setdefault(name, value)

    if modifiers:
        kwargs['_modifiers'] = modifiers

    logger.debug(f"Parsed method '{method_name}': base='{base_method}', params={extracted_params}, modifiers={modifiers}")

    for family in self.families:
        # A family qualifies via either the original name or the base method.
        if not (family.can_handle(method_name) or family.can_handle(base_method)):
            continue
        try:
            # The original name is passed on so families can parse it themselves.
            return family.execute(model, x, method_name, framework, **kwargs)
        except Exception as e:
            logger.info(f"Family {family.__class__.__name__} failed, trying next: {e}")

    return self._fallback_to_wrappers(model, x, method_name, framework, **kwargs)
def _fallback_to_wrappers(self, model, x, method_name: str, framework: str, **kwargs): """Fallback to original implementations if wrappers exist.""" logger.debug(f"No family can handle {method_name}, trying fallback") if framework == 'tensorflow': try: from ..tf_signxai.methods_impl.wrappers import calculate_relevancemap # Remove target_class from kwargs since we pass it as neuron_selection kwargs_copy = kwargs.copy() target_class = kwargs_copy.pop('target_class', None) return calculate_relevancemap( method_name, x, model, neuron_selection=target_class, **kwargs_copy ) except ImportError: # If wrappers don't exist, method is not supported raise NotImplementedError(f"Method '{method_name}' is not implemented for TensorFlow") elif framework == 'pytorch': try: from ..torch_signxai.methods_impl.wrappers import calculate_relevancemap # Remove target_class from kwargs since we pass it explicitly kwargs_copy = kwargs.copy() target_class = kwargs_copy.pop('target_class', None) return calculate_relevancemap( model=model, input_tensor=x, method=method_name, target_class=target_class, **kwargs_copy ) except ImportError: # Try zennit_impl as final fallback try: from ..torch_signxai.methods_impl.zennit_impl import calculate_relevancemap as zennit_calc kwargs_copy = kwargs.copy() target_class = kwargs_copy.pop('target_class', None) return zennit_calc( model=model, input_tensor=x, method=method_name, target_class=target_class, **kwargs_copy ) except: raise NotImplementedError(f"Method '{method_name}' is not implemented for PyTorch") else: raise ValueError(f"Unsupported framework: {framework}")
def get_supported_methods(self) -> set:
    """Get methods that are ACTUALLY supported by BOTH frameworks.

    Returns the intersection of the TensorFlow and PyTorch method sets, not
    the union, so that only comparable methods are reported. Non-string
    entries, names starting with ``tf_exact_`` and the ``WrapperDelegation``
    marker are removed from the result.

    The methods advertised by the registered families are counted for the
    diagnostic log output only; they do not influence the returned set.
    Discovery statistics are emitted through the module logger at INFO level.

    Returns:
        The filtered set of method names common to both frameworks.
    """
    # Methods that are actually implemented in each framework
    tensorflow_methods = self._get_all_tensorflow_method_variations()
    pytorch_methods = self._get_all_pytorch_method_variations()

    # Methods from registered families (reported as a supplement)
    family_methods = set()
    for family in self.families:
        family_methods.update(family.supported_methods)

    # Intersection - methods that work in BOTH frameworks
    common_methods = tensorflow_methods & pytorch_methods

    logger.info("Method Family Registry Discovery:")
    logger.info("- TensorFlow methods: %d", len(tensorflow_methods))
    logger.info("- PyTorch methods: %d", len(pytorch_methods))
    logger.info("- Family methods: %d", len(family_methods))
    logger.info("- Common (intersection): %d", len(common_methods))

    # Filter out non-string methods and special markers
    filtered_methods = {
        m for m in common_methods
        if isinstance(m, str)
        and not m.startswith('tf_exact_')
        and m != "WrapperDelegation"
    }

    logger.info("- Final filtered methods: %d", len(filtered_methods))
    return filtered_methods
def _get_all_pytorch_method_variations(self) -> set:
    """Return the method names treated as available for PyTorch.

    If ``SUPPORTED_ZENNIT_METHODS`` can be imported from the zennit
    implementation, a hard-coded set of method names is returned; the
    imported mapping is only used to log how many core methods it lists.
    If the import fails, a small set of basic methods is returned instead.

    Returns:
        set: Method-name strings.
    """
    log = logging.getLogger(__name__)

    # Only the import can raise ImportError, so keep the try body minimal.
    try:
        from ..torch_signxai.methods_impl.zennit_impl import SUPPORTED_ZENNIT_METHODS
    except ImportError as e:
        log.warning("Could not import SUPPORTED_ZENNIT_METHODS: %s", e)
        # Fallback to basic methods that definitely work
        return {
            'gradient', 'guided_backprop', 'deconvnet',
            'smoothgrad', 'integrated_gradients', 'grad_cam'
        }

    # Diagnostic only: count zennit entries, ignoring special markers and
    # wrapper delegations. The returned set does not depend on this.
    core_methods = {
        m for m in set(SUPPORTED_ZENNIT_METHODS.keys())
        if isinstance(m, str)
        and not m.startswith('tf_exact_')
        and m != "WrapperDelegation"
        and SUPPORTED_ZENNIT_METHODS[m] != "WrapperDelegation"
    }
    log.debug("PyTorch core implemented methods: %d", len(core_methods))

    # Methods that have counterparts in TensorFlow. All of them are returned
    # without further filtering (per the original author's note, they work
    # through the Method Families).
    common_implementable_methods = {
        # Core gradient methods
        'gradient', 'gradient_x_input', 'gradient_x_sign', 'gradient_x_input_x_sign',
        'gradient_x_sign_mu', 'gradient_x_sign_mu_0', 'gradient_x_sign_mu_0_5',
        'gradient_x_sign_mu_neg_0_5', 'input_t_gradient',

        # Stochastic methods
        'smoothgrad', 'smoothgrad_x_input', 'smoothgrad_x_sign', 'smoothgrad_x_input_x_sign',
        'smoothgrad_x_sign_mu', 'smoothgrad_x_sign_mu_0', 'smoothgrad_x_sign_mu_0_5',
        'smoothgrad_x_sign_mu_neg_0_5',
        'vargrad', 'vargrad_x_input', 'vargrad_x_sign', 'vargrad_x_input_x_sign',
        'integrated_gradients', 'integrated_gradients_x_input', 'integrated_gradients_x_sign',
        'integrated_gradients_x_input_x_sign',

        # Guided methods
        'guided_backprop', 'guided_backprop_x_input', 'guided_backprop_x_sign',
        'guided_backprop_x_input_x_sign', 'guided_backprop_x_sign_mu',
        'guided_backprop_x_sign_mu_0', 'guided_backprop_x_sign_mu_0_5',
        'guided_backprop_x_sign_mu_neg_0_5',
        'deconvnet', 'deconvnet_x_input', 'deconvnet_x_sign', 'deconvnet_x_input_x_sign',
        'deconvnet_x_sign_mu', 'deconvnet_x_sign_mu_0', 'deconvnet_x_sign_mu_0_5',
        'deconvnet_x_sign_mu_neg_0_5',

        # DeepLift ('deep_lift', 'deeplift') is intentionally excluded: it
        # doesn't work properly.

        # LRP core methods that work in both frameworks - INCLUDING ALL MODIFIERS
        'lrp', 'lrp_epsilon', 'lrp_z', 'lrp_gamma', 'lrp_flat', 'lrp_w_square',
        'lrp_alpha_1_beta_0', 'lrp_alpha_2_beta_1',

        # LRP methods with x_input, x_sign, x_input_x_sign modifiers
        'lrp_x_input', 'lrp_x_sign', 'lrp_x_input_x_sign',
        'lrp_epsilon_x_input', 'lrp_epsilon_x_sign', 'lrp_epsilon_x_input_x_sign',
        'lrp_z_x_input', 'lrp_z_x_sign', 'lrp_z_x_input_x_sign',
        'lrp_gamma_x_input', 'lrp_gamma_x_sign', 'lrp_gamma_x_input_x_sign',
        'lrp_flat_x_input', 'lrp_flat_x_sign', 'lrp_flat_x_input_x_sign',
        'lrp_w_square_x_input', 'lrp_w_square_x_sign', 'lrp_w_square_x_input_x_sign',
        'lrp_alpha_1_beta_0_x_input', 'lrp_alpha_1_beta_0_x_sign',
        'lrp_alpha_1_beta_0_x_input_x_sign',
        'lrp_alpha_2_beta_1_x_input', 'lrp_alpha_2_beta_1_x_sign',
        'lrp_alpha_2_beta_1_x_input_x_sign',
        'lrp_z_plus_x_input', 'lrp_z_plus_x_sign', 'lrp_z_plus_x_input_x_sign',

        # All LRP epsilon variations
        'lrp_epsilon_0_001', 'lrp_epsilon_0_01', 'lrp_epsilon_0_1', 'lrp_epsilon_0_2',
        'lrp_epsilon_0_25', 'lrp_epsilon_0_5', 'lrp_epsilon_1', 'lrp_epsilon_2',
        'lrp_epsilon_5', 'lrp_epsilon_10', 'lrp_epsilon_20', 'lrp_epsilon_50',
        'lrp_epsilon_75', 'lrp_epsilon_100',

        # LRP epsilon with std_x
        'lrp_epsilon_0_1_std_x', 'lrp_epsilon_0_25_std_x', 'lrp_epsilon_0_5_std_x',
        'lrp_epsilon_1_std_x', 'lrp_epsilon_2_std_x', 'lrp_epsilon_3_std_x',

        # LRP Sign variations
        'lrpsign_z', 'lrpsign_epsilon_0_001', 'lrpsign_epsilon_0_01', 'lrpsign_epsilon_0_1',
        'lrpsign_epsilon_0_2', 'lrpsign_epsilon_0_5', 'lrpsign_epsilon_1', 'lrpsign_epsilon_5',
        'lrpsign_epsilon_10', 'lrpsign_epsilon_20', 'lrpsign_epsilon_50', 'lrpsign_epsilon_75',
        'lrpsign_epsilon_100', 'lrpsign_epsilon_100_mu_0', 'lrpsign_epsilon_100_mu_0_5',
        'lrpsign_epsilon_100_mu_neg_0_5',
        'lrpsign_epsilon_0_1_std_x', 'lrpsign_epsilon_0_25_std_x',
        'lrpsign_epsilon_0_25_std_x_mu_0', 'lrpsign_epsilon_0_25_std_x_mu_0_5',
        'lrpsign_epsilon_0_25_std_x_mu_neg_0_5',
        'lrpsign_epsilon_0_5_std_x', 'lrpsign_epsilon_1_std_x', 'lrpsign_epsilon_2_std_x',
        'lrpsign_epsilon_3_std_x',
        'lrpsign_alpha_1_beta_0', 'lrpsign_sequential_composite_a',
        'lrpsign_sequential_composite_b',

        # LRP Z variations
        'lrpz_epsilon_0_001', 'lrpz_epsilon_0_01', 'lrpz_epsilon_0_1', 'lrpz_epsilon_0_2',
        'lrpz_epsilon_0_5', 'lrpz_epsilon_1', 'lrpz_epsilon_5', 'lrpz_epsilon_10',
        'lrpz_epsilon_20', 'lrpz_epsilon_50', 'lrpz_epsilon_75', 'lrpz_epsilon_100',
        'lrpz_epsilon_0_1_std_x', 'lrpz_epsilon_0_25_std_x', 'lrpz_epsilon_0_5_std_x',
        'lrpz_epsilon_1_std_x', 'lrpz_epsilon_2_std_x', 'lrpz_epsilon_3_std_x',
        'lrpz_sequential_composite_a', 'lrpz_sequential_composite_b',

        # Flat LRP variations
        'flatlrp_z', 'flatlrp_epsilon_0_01', 'flatlrp_epsilon_0_1', 'flatlrp_epsilon_1',
        'flatlrp_epsilon_10', 'flatlrp_epsilon_20', 'flatlrp_epsilon_100',
        'flatlrp_epsilon_0_1_std_x', 'flatlrp_epsilon_0_25_std_x', 'flatlrp_epsilon_0_5_std_x',
        'flatlrp_sequential_composite_a', 'flatlrp_sequential_composite_b',

        # W2 LRP variations
        'w2lrp_z', 'w2lrp_epsilon_0_01', 'w2lrp_epsilon_0_1', 'w2lrp_epsilon_1',
        'w2lrp_epsilon_10', 'w2lrp_epsilon_20', 'w2lrp_epsilon_100',
        'w2lrp_epsilon_0_1_std_x', 'w2lrp_epsilon_0_25_std_x', 'w2lrp_epsilon_0_5_std_x',
        'w2lrp_sequential_composite_a', 'w2lrp_sequential_composite_b',

        # VGG16ILSVRC specific (ZB-LRP with ImageNet bounds)
        'zblrp_z_VGG16ILSVRC', 'zblrp_epsilon_0_001_VGG16ILSVRC',
        'zblrp_epsilon_0_01_VGG16ILSVRC', 'zblrp_epsilon_0_1_VGG16ILSVRC',
        'zblrp_epsilon_0_2_VGG16ILSVRC', 'zblrp_epsilon_0_5_VGG16ILSVRC',
        'zblrp_epsilon_1_VGG16ILSVRC', 'zblrp_epsilon_5_VGG16ILSVRC',
        'zblrp_epsilon_10_VGG16ILSVRC', 'zblrp_epsilon_20_VGG16ILSVRC',
        'zblrp_epsilon_100_VGG16ILSVRC',
        'zblrp_epsilon_0_1_std_x_VGG16ILSVRC', 'zblrp_epsilon_0_25_std_x_VGG16ILSVRC',
        'zblrp_epsilon_0_5_std_x_VGG16ILSVRC',
        'zblrp_sequential_composite_a_VGG16ILSVRC', 'zblrp_sequential_composite_b_VGG16ILSVRC',

        # LRP with dot notation (for iNNvestigate compatibility);
        # 'lrp' itself is already listed above.
        'lrp.epsilon', 'lrp.z', 'lrp.gamma', 'lrp.flat', 'lrp.w_square',
        'lrp.alpha_1_beta_0', 'lrp.alpha_2_beta_1', 'lrp.alpha_beta',
        'lrp.sequential_composite_a', 'lrp.sequential_composite_b',
        'lrp.z_plus', 'lrp.z_plus_fast', 'lrp.stdxepsilon',
        'lrp.alpha_1_beta_0_IB', 'lrp.alpha_2_beta_1_IB', 'lrp.epsilon_IB', 'lrp.z_IB',
        'lrp.rule_until_index',
        # 'deeplift.wrapper' is excluded as DeepLift doesn't work.

        # CAM methods
        'grad_cam', 'grad_cam_x_input', 'grad_cam_x_sign', 'grad_cam_x_input_x_sign',
        'grad_cam_VGG16ILSVRC', 'grad_cam_timeseries',

        # Others that work
        'random_uniform', 'occlusion',
    }

    log.debug("PyTorch final common methods: %d", len(common_implementable_methods))
    return common_implementable_methods

def _get_all_tensorflow_method_variations(self) -> set:
    """Return the method names treated as available for TensorFlow.

    The result is the union of four hard-coded/derived groups: the
    iNNvestigate analyzer names, the custom SignXAI TensorFlow methods, the
    VGG16ILSVRC-specific ZB-LRP names, and the ``_x_input`` / ``_x_sign`` /
    ``_x_input_x_sign`` variations of every base method that appears in the
    first two groups.

    Returns:
        set: Method-name strings.
    """
    log = logging.getLogger(__name__)

    # Per the original author's note, based on the analyzers available in
    # tf_signxai/methods/innvestigate/analyzer/__init__.py
    innvestigate_analyzers = {
        # Core gradient methods
        'gradient', 'input_t_gradient', 'deconvnet', 'guided_backprop',
        'integrated_gradients', 'smoothgrad', 'vargrad',

        # Gradient with x_input and x_sign variations
        'gradient_x_input', 'gradient_x_sign', 'gradient_x_input_x_sign',
        'gradient_x_sign_mu', 'gradient_x_sign_mu_0', 'gradient_x_sign_mu_0_5',
        'gradient_x_sign_mu_neg_0_5',

        # Guided backprop variations
        'guided_backprop_x_input', 'guided_backprop_x_sign', 'guided_backprop_x_input_x_sign',
        'guided_backprop_x_sign_mu', 'guided_backprop_x_sign_mu_0',
        'guided_backprop_x_sign_mu_0_5', 'guided_backprop_x_sign_mu_neg_0_5',

        # Deconvnet variations
        'deconvnet_x_input', 'deconvnet_x_sign', 'deconvnet_x_input_x_sign',
        'deconvnet_x_sign_mu', 'deconvnet_x_sign_mu_0', 'deconvnet_x_sign_mu_0_5',
        'deconvnet_x_sign_mu_neg_0_5',

        # Smoothgrad variations
        'smoothgrad_x_input', 'smoothgrad_x_sign', 'smoothgrad_x_input_x_sign',
        'smoothgrad_x_sign_mu', 'smoothgrad_x_sign_mu_0', 'smoothgrad_x_sign_mu_0_5',
        'smoothgrad_x_sign_mu_neg_0_5',

        # Integrated gradients variations
        'integrated_gradients_x_input', 'integrated_gradients_x_sign',
        'integrated_gradients_x_input_x_sign',

        # VarGrad variations
        'vargrad_x_input', 'vargrad_x_sign', 'vargrad_x_input_x_sign',

        # DeepLift ('deep_lift', 'deeplift', 'deeplift.wrapper') is
        # intentionally excluded: it doesn't work in TensorFlow.

        # Core LRP methods that actually exist (use dot notation for iNNvestigate)
        'lrp', 'lrp.z', 'lrp.z_IB', 'lrp.gamma', 'lrp.epsilon', 'lrp.stdxepsilon',
        'lrp.epsilon_IB', 'lrp.w_square', 'lrp.flat', 'lrp.alpha_beta',
        'lrp.alpha_2_beta_1', 'lrp.alpha_2_beta_1_IB', 'lrp.alpha_1_beta_0',
        'lrp.alpha_1_beta_0_IB', 'lrp.z_plus', 'lrp.z_plus_fast',
        'lrp.sequential_composite_a', 'lrp.sequential_composite_b', 'lrp.rule_until_index',

        # Add underscore notation for method families (these get mapped to dot notation)
        'lrp_epsilon', 'lrp_z', 'lrp_gamma', 'lrp_flat', 'lrp_w_square',
        'lrp_alpha_1_beta_0', 'lrp_alpha_2_beta_1',

        # All epsilon variations
        'lrp_epsilon_0_001', 'lrp_epsilon_0_01', 'lrp_epsilon_0_1', 'lrp_epsilon_0_2',
        'lrp_epsilon_0_25', 'lrp_epsilon_0_5', 'lrp_epsilon_1', 'lrp_epsilon_2',
        'lrp_epsilon_5', 'lrp_epsilon_10', 'lrp_epsilon_20', 'lrp_epsilon_50',
        'lrp_epsilon_75', 'lrp_epsilon_100',

        # LRP epsilon with std_x
        'lrp_epsilon_0_1_std_x', 'lrp_epsilon_0_25_std_x', 'lrp_epsilon_0_5_std_x',
        'lrp_epsilon_1_std_x', 'lrp_epsilon_2_std_x', 'lrp_epsilon_3_std_x',

        # LRP Sign variations
        'lrpsign_z', 'lrpsign_epsilon_0_001', 'lrpsign_epsilon_0_01', 'lrpsign_epsilon_0_1',
        'lrpsign_epsilon_0_2', 'lrpsign_epsilon_0_5', 'lrpsign_epsilon_1', 'lrpsign_epsilon_5',
        'lrpsign_epsilon_10', 'lrpsign_epsilon_20', 'lrpsign_epsilon_50', 'lrpsign_epsilon_75',
        'lrpsign_epsilon_100', 'lrpsign_epsilon_100_mu_0', 'lrpsign_epsilon_100_mu_0_5',
        'lrpsign_epsilon_100_mu_neg_0_5',
        'lrpsign_epsilon_0_1_std_x', 'lrpsign_epsilon_0_25_std_x',
        'lrpsign_epsilon_0_25_std_x_mu_0', 'lrpsign_epsilon_0_25_std_x_mu_0_5',
        'lrpsign_epsilon_0_25_std_x_mu_neg_0_5',
        'lrpsign_epsilon_0_5_std_x', 'lrpsign_epsilon_1_std_x', 'lrpsign_epsilon_2_std_x',
        'lrpsign_epsilon_3_std_x',
        'lrpsign_alpha_1_beta_0', 'lrpsign_sequential_composite_a',
        'lrpsign_sequential_composite_b',

        # LRP Z variations
        'lrpz_epsilon_0_001', 'lrpz_epsilon_0_01', 'lrpz_epsilon_0_1', 'lrpz_epsilon_0_2',
        'lrpz_epsilon_0_5', 'lrpz_epsilon_1', 'lrpz_epsilon_5', 'lrpz_epsilon_10',
        'lrpz_epsilon_20', 'lrpz_epsilon_50', 'lrpz_epsilon_75', 'lrpz_epsilon_100',
        'lrpz_epsilon_0_1_std_x', 'lrpz_epsilon_0_25_std_x', 'lrpz_epsilon_0_5_std_x',
        'lrpz_epsilon_1_std_x', 'lrpz_epsilon_2_std_x', 'lrpz_epsilon_3_std_x',
        'lrpz_sequential_composite_a', 'lrpz_sequential_composite_b',

        # Flat LRP variations
        'flatlrp_z', 'flatlrp_epsilon_0_01', 'flatlrp_epsilon_0_1', 'flatlrp_epsilon_1',
        'flatlrp_epsilon_10', 'flatlrp_epsilon_20', 'flatlrp_epsilon_100',
        'flatlrp_epsilon_0_1_std_x', 'flatlrp_epsilon_0_25_std_x', 'flatlrp_epsilon_0_5_std_x',
        'flatlrp_sequential_composite_a', 'flatlrp_sequential_composite_b',

        # W2 LRP variations
        'w2lrp_z', 'w2lrp_epsilon_0_01', 'w2lrp_epsilon_0_1', 'w2lrp_epsilon_1',
        'w2lrp_epsilon_10', 'w2lrp_epsilon_20', 'w2lrp_epsilon_100',
        'w2lrp_epsilon_0_1_std_x', 'w2lrp_epsilon_0_25_std_x', 'w2lrp_epsilon_0_5_std_x',
        'w2lrp_sequential_composite_a', 'w2lrp_sequential_composite_b',
    }

    # Custom SignXAI TensorFlow implementations
    custom_tf_methods = {
        'grad_cam', 'grad_cam_timeseries', 'occlusion',
        'grad_cam_x_input', 'grad_cam_x_sign', 'grad_cam_x_input_x_sign',
        'grad_cam_VGG16ILSVRC', 'random_uniform',
    }

    # VGG16ILSVRC specific methods (ZB-LRP with ImageNet bounds)
    vgg16_specific = {
        'zblrp_z_VGG16ILSVRC', 'zblrp_epsilon_0_001_VGG16ILSVRC',
        'zblrp_epsilon_0_01_VGG16ILSVRC', 'zblrp_epsilon_0_1_VGG16ILSVRC',
        'zblrp_epsilon_0_2_VGG16ILSVRC', 'zblrp_epsilon_0_5_VGG16ILSVRC',
        'zblrp_epsilon_1_VGG16ILSVRC', 'zblrp_epsilon_5_VGG16ILSVRC',
        'zblrp_epsilon_10_VGG16ILSVRC', 'zblrp_epsilon_20_VGG16ILSVRC',
        'zblrp_epsilon_100_VGG16ILSVRC',
        'zblrp_epsilon_0_1_std_x_VGG16ILSVRC', 'zblrp_epsilon_0_25_std_x_VGG16ILSVRC',
        'zblrp_epsilon_0_5_std_x_VGG16ILSVRC',
        'zblrp_sequential_composite_a_VGG16ILSVRC', 'zblrp_sequential_composite_b_VGG16ILSVRC',
    }

    # Base methods whose x_input / x_sign modifier variations are added,
    # provided the base itself is listed above.
    base_methods_for_modifiers = {
        # Gradient-based methods
        'gradient', 'guided_backprop', 'deconvnet', 'smoothgrad', 'vargrad',
        'integrated_gradients',
        # Core LRP methods
        'lrp', 'lrp_z', 'lrp_gamma', 'lrp_epsilon', 'lrp_flat', 'lrp_w_square',
        'lrp_alpha_1_beta_0', 'lrp_alpha_2_beta_1',
        # LRP z variations
        'lrp_z_plus',
        # GradCAM
        'grad_cam',
    }

    implementable_variations = set()
    for base_method in base_methods_for_modifiers:
        if base_method in innvestigate_analyzers or base_method in custom_tf_methods:
            implementable_variations.add(f'{base_method}_x_input')
            implementable_variations.add(f'{base_method}_x_sign')
            implementable_variations.add(f'{base_method}_x_input_x_sign')

    # Combine all actual methods
    methods = (
        innvestigate_analyzers | custom_tf_methods | vgg16_specific | implementable_variations
    )

    log.debug("TensorFlow actual methods: %d total", len(methods))
    log.debug(
        "TF Core methods: %d, Custom: %d, Variations: %d",
        len(innvestigate_analyzers),
        len(custom_tf_methods),
        len(implementable_variations),
    )
    return methods
# Module-level singleton slot; stays None until get_registry() is first called.
_registry = None


def get_registry() -> MethodFamilyRegistry:
    """Return the shared registry, building it on the first call."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = MethodFamilyRegistry()
    return _registry