# Source code for signxai.common.method_normalizer

"""
Method name normalizer to handle naming inconsistencies between frameworks.

This module provides canonical method naming and aliasing support to ensure
consistent method names across TensorFlow and PyTorch implementations.
"""

import re
from typing import Dict, Optional, Set


class MethodNormalizer:
    """Normalize method names across frameworks to handle inconsistencies.

    Every method here is a classmethod, so no instance is needed; all
    lookup tables are class attributes.
    """

    # Default parameters for methods (not read by this class itself).
    METHOD_PRESETS: Dict[str, Dict] = {
        'smoothgrad': {
            'noise_level': 0.1,
            'num_samples': 25
        },
        'integrated_gradients': {
            'steps': 50,
            'baseline': None
        },
        'vargrad': {
            'noise_level': 0.1,
            'num_samples': 25
        },
        'lrp_epsilon': {
            'epsilon': 0.1
        },
        'lrp_alpha_1_beta_0': {
            'alpha': 1.0,
            'beta': 0.0
        },
        'lrp_alpha_2_beta_1': {
            'alpha': 2.0,
            'beta': 1.0
        }
    }

    # Canonical names mapping (aliases -> canonical)
    # Using underscored versions as the canonical names for proper TensorFlow/PyTorch compatibility
    CANONICAL_NAMES: Dict[str, str] = {
        # Integrated gradients variations (integrated_gradients is canonical)
        'integratedgradients': 'integrated_gradients',
        'integrated_gradient': 'integrated_gradients',
        'integratedgradient': 'integrated_gradients',
        'integrated_gradients': 'integrated_gradients',

        # Grad-CAM variations (grad_cam is canonical for consistency)
        'gradcam': 'grad_cam',
        'grad_cam': 'grad_cam',
        'gradCAM': 'grad_cam',
        'GradCAM': 'grad_cam',

        # Gradient variations (keep both for now, they may have different behaviors)
        'gradients': 'gradient',  # Map plural to singular as default
        'gradient': 'gradient',

        # DeepLift variations
        'deeplift_method': 'deeplift',
        'deeplift': 'deeplift',
        'deep_lift': 'deeplift',

        # LRP variations
        'lrp': 'lrp',
        'LRP': 'lrp',

        # Guided backprop variations
        'guided_backprop': 'guided_backprop',
        'guidedbackprop': 'guided_backprop',
        'guided_backpropagation': 'guided_backprop',

        # Deconvnet variations
        'deconvnet': 'deconvnet',
        'deconvolution': 'deconvnet',
        'deconv': 'deconvnet',
    }

    # Methods that are framework-specific (not available in both)
    FRAMEWORK_SPECIFIC: Dict[str, Set[str]] = {
        'pytorch': {
            'lrp_z_x_input',
            'lrp_z_x_input_x_sign',
            'lrp_z_x_sign',
            'lrpsign_epsilon_100_improved',
            'lrpsign_epsilon_20_improved',
            'flatlrp_z',
        },
        'tensorflow': set(),  # Currently no TF-only methods
    }

    # Methods that are known to be broken or disabled
    DISABLED_METHODS: Set[str] = {
        'deconvnet_x_input_DISABLED_BROKEN_WRAPPER',
        'deconvnet_x_input_x_sign_DISABLED_BROKEN_WRAPPER',
        'deconvnet_x_sign_DISABLED_BROKEN_WRAPPER',
        'deconvnet_x_sign_mu_0_5_DISABLED_BROKEN_WRAPPER',
        'smoothgrad_x_input_DISABLED_BROKEN_WRAPPER',
    }

    # Suffix tables used by _extract_base_method.
    # Transformation markers, checked in this order; the first one found wins.
    _TRANSFORM_SUFFIXES = ('_x_input_x_sign', '_x_input', '_x_sign')
    # Trailing numeric parameters such as _0_1, _0_25, _100.
    _PARAM_SUFFIX_RE = re.compile(r'_\d+(_\d+)*$')
    # Trailing mu parameter such as _mu or _mu_0_5.
    _MU_SUFFIX_RE = re.compile(r'_mu(_\d+(_\d+)?)?$')
    # Model-specific suffixes; longer names listed before their prefixes.
    _MODEL_SUFFIXES = ('_VGG16ILSVRC', '_VGG16', '_ResNet50', '_MNISTCNN')

    @classmethod
    def normalize(cls, method_name: str, framework: Optional[str] = None) -> str:
        """
        Normalize a method name to its canonical form.

        Only the base of the name (what ``_extract_base_method`` leaves after
        stripping transformation, parameter and model suffixes) is rewritten;
        the suffixes are kept verbatim, e.g. ``'gradients_x_input'`` becomes
        ``'gradient_x_input'``. Names whose base has no known alias are
        returned unchanged.

        Args:
            method_name: The method name to normalize
            framework: The framework being used ('tensorflow' or 'pytorch'),
                matched case-insensitively. When given and known, the
                canonical name is rejected if it is listed as specific to a
                different framework. Unknown framework names are not checked.

        Returns:
            The canonical method name

        Raises:
            ValueError: If the method is disabled or not available in the framework
        """
        # Reject disabled methods up front: either listed explicitly, or
        # carrying the broken-wrapper marker anywhere in the name.
        if method_name in cls.DISABLED_METHODS:
            raise ValueError(f"Method '{method_name}' is currently disabled/broken")
        if '_DISABLED_BROKEN_WRAPPER' in method_name:
            raise ValueError(f"Method '{method_name}' is marked as broken")

        if method_name in cls.CANONICAL_NAMES.values():
            # Already canonical: nothing to rewrite.
            canonical = method_name
        else:
            base_method = cls._extract_base_method(method_name)
            canonical_base = cls.CANONICAL_NAMES.get(base_method)
            if canonical_base is None:
                # No mapping found, use as-is
                canonical = method_name
            else:
                # base_method is always a prefix of method_name, so swap the
                # prefix and keep the remaining suffixes untouched.
                canonical = canonical_base + method_name[len(base_method):]

        # Check framework availability if specified
        if framework:
            framework = framework.lower()
            if framework in cls.FRAMEWORK_SPECIFIC:
                for other_fw, fw_methods in cls.FRAMEWORK_SPECIFIC.items():
                    if other_fw != framework and canonical in fw_methods:
                        raise ValueError(
                            f"Method '{canonical}' is not available in {framework}, "
                            f"only in {other_fw}"
                        )

        return canonical

    @classmethod
    def _extract_base_method(cls, method_name: str) -> str:
        """
        Extract the base method name without parameters or transformations.

        1. Cut at the first transformation marker found (checked in the order
           ``_x_input_x_sign``, ``_x_input``, ``_x_sign``), dropping the
           marker and everything after it.
        2. Repeatedly strip a trailing numeric parameter (``_0_1``, ``_100``),
           a trailing ``_mu`` parameter and a trailing model suffix
           (``_VGG16`` etc.) until none applies. This makes the result
           independent of the order in which those suffixes were appended,
           e.g. both ``'gradients_0_1_VGG16'`` and ``'gradients_VGG16_0_1'``
           reduce to ``'gradients'``.

        Args:
            method_name: Full method name

        Returns:
            Base method name; always a prefix of ``method_name``
        """
        base = method_name
        for transform in cls._TRANSFORM_SUFFIXES:
            if transform in base:
                base = base.split(transform)[0]
                break

        # Every step only removes a trailing suffix, so the loop terminates.
        previous = None
        while base != previous:
            previous = base
            base = cls._PARAM_SUFFIX_RE.sub('', base)
            base = cls._MU_SUFFIX_RE.sub('', base)
            for model in cls._MODEL_SUFFIXES:
                if base.endswith(model):
                    base = base[:-len(model)]
                    break
        return base

    @classmethod
    def get_aliases(cls, canonical_name: str) -> Set[str]:
        """
        Get all known aliases for a canonical method name.

        Args:
            canonical_name: The canonical method name

        Returns:
            Set of all aliases (including the canonical name itself). A name
            with no registered aliases yields a one-element set.
        """
        aliases = {
            alias
            for alias, canonical in cls.CANONICAL_NAMES.items()
            if canonical == canonical_name
        }
        aliases.add(canonical_name)
        return aliases

    @classmethod
    def is_framework_specific(cls, method_name: str, framework: str) -> bool:
        """
        Check if a method is specific to a particular framework.

        Args:
            method_name: The method name to check; normalized first
            framework: The framework to check against (case-insensitive)

        Returns:
            True if the method is specific to the given framework

        Raises:
            ValueError: Propagated from ``normalize`` for disabled methods
        """
        framework = framework.lower()
        canonical = cls.normalize(method_name)
        return canonical in cls.FRAMEWORK_SPECIFIC.get(framework, set())

    @classmethod
    def get_framework_methods(cls, framework: str) -> Set[str]:
        """
        Get all framework-specific methods for a given framework.

        Args:
            framework: The framework name (case-insensitive)

        Returns:
            A new set of framework-specific method names (empty for an
            unknown framework). It is a copy, so mutating it does not alter
            ``FRAMEWORK_SPECIFIC``.
        """
        return set(cls.FRAMEWORK_SPECIFIC.get(framework.lower(), ()))