# Source code for signxai.common.method_parser

# signxai/common/method_parser.py
import re
from typing import Dict, Any, List, Optional, Tuple


class MethodParser:
    """
    Parses XAI method strings into a structured format with comprehensive
    parameter extraction.

    Handles complex method names like:

    - lrp_alpha_2_beta_1 -> base "lrp_alpha_2_beta_1" with alpha=2.0, beta=1.0
    - gradient_x_sign_mu_neg_0_5 -> base "gradient", modifier "x_sign", mu=-0.5
    - lrp_epsilon_0_25_std_x -> base "lrp_epsilon", modifier "std_x", epsilon=0.25

    Numeric values use an underscore in place of the decimal point
    ("0_25" -> 0.25); "neg_" in front of a mu value makes it negative.
    """

    # Prefixes that combine with a rule name to form a two-part LRP base name.
    _LRP_PREFIXES = ('lrp', 'lrpsign', 'lrpz', 'flatlrp', 'w2lrp', 'zblrp')
    # Rule names recognised directly after one of the prefixes above.
    _LRP_RULES = ('epsilon', 'gamma', 'z', 'flat', 'w', 'alpha', 'sequential')
    # Tokens that _is_param reports as parameter names.
    _PARAM_NAMES = ('epsilon', 'mu', 'alpha', 'beta', 'steps', 'noise_level',
                    'num_samples', 'stdfactor', 'gamma', 'noise', 'samples')

    _ALPHA_BETA_RE = re.compile(r'^([\w]+?)_alpha_(\d+)_beta_(\d+)')
    # "<whole>[_<decimal>]" values. The lookahead rejects a truncated match
    # (0_1 taken out of 0_1_5) but, unlike the stricter (?![_\d]), still
    # accepts a value that is followed by a modifier such as _std_x or _x_input.
    _EPSILON_RE = re.compile(r'epsilon_(\d+)(?:_(\d+))?(?!_?\d)')
    _MU_RE = re.compile(r'mu_(neg_)?(\d+)(?:_(\d+))?(?!_?\d)')
    _GAMMA_RE = re.compile(r'gamma_(\d+)(?:_(\d+))?(?!_?\d)')
    _STEPS_RE = re.compile(r'steps_(\d+)')
    _NOISE_RE = re.compile(r'noise_(\d+)(?:_(\d+))?')
    _SAMPLES_RE = re.compile(r'(?:samples|num_samples)_(\d+)')
    _STDFACTOR_RE = re.compile(r'stdfactor_(\d+)(?:_(\d+))?')

    def parse(self, method_name: str) -> Dict[str, Any]:
        """
        Parses a method string into base method, modifiers, and parameters.

        Args:
            method_name (str): The method string. It is lower-cased before
                matching, so parsing is case-insensitive.

        Returns:
            A dictionary with:
            - base_method: The canonical base method name
            - modifiers: List of modifiers (x_input, x_sign, std_x,
              ignore_bias, timeseries)
            - params: Extracted parameters with proper types
            - original_name: The original, un-lowered method name
        """
        original = method_name
        method_lower = method_name.lower()

        # Base name first; it may already carry alpha/beta/epsilon/gamma.
        base_method, params = self._extract_compound_method(method_lower)
        modifiers = self._extract_modifiers(method_lower)
        # The generic scan runs last, so its values win on key collisions.
        params.update(self._extract_parameters(method_lower))

        return {
            'base_method': base_method,
            'modifiers': modifiers,
            'params': params,
            'original_name': original,
        }

    @staticmethod
    def _to_float(whole: str, decimal: Optional[str]) -> float:
        """Combine digit strings into a float: ('0', '25') -> 0.25, ('5', None) -> 5.0."""
        if decimal:
            return float(f"{int(whole)}.{decimal}")
        return float(int(whole))

    @classmethod
    def _leading_number(cls, tokens: List[str]) -> Optional[float]:
        """
        Decode the value spelled by the leading digit-only tokens.

        ['0', '25', 'std', 'x'] -> 0.25 and ['50', 'x'] -> 50.0. Returns None
        when there are no leading digit tokens or more than two of them.
        Non-numeric tokens therefore never reach float().
        """
        digits: List[str] = []
        for token in tokens:
            if not token.isdecimal():
                break
            digits.append(token)
        if len(digits) == 1:
            return cls._to_float(digits[0], None)
        if len(digits) == 2:
            return cls._to_float(digits[0], digits[1])
        return None

    def _extract_compound_method(self, method: str) -> Tuple[str, Dict[str, Any]]:
        """
        Determine the canonical base method name of a lower-cased method string.

        Returns (base_method, params). params holds the values that are part
        of the base name itself: alpha/beta for alpha-beta rules, and
        epsilon/gamma for LRP epsilon/gamma rules.
        """
        params: Dict[str, Any] = {}

        # LRP alpha-beta pattern (e.g. lrp_alpha_2_beta_1).
        alpha_beta_match = self._ALPHA_BETA_RE.match(method)
        if alpha_beta_match:
            prefix, alpha, beta = alpha_beta_match.groups()
            params['alpha'] = float(alpha)
            params['beta'] = float(beta)
            # Two combinations keep a dedicated name; the rest share one.
            if params['alpha'] == 1 and params['beta'] == 0:
                return f"{prefix}_alpha_1_beta_0", params
            if params['alpha'] == 2 and params['beta'] == 1:
                return f"{prefix}_alpha_2_beta_1", params
            return f"{prefix}_alpha_beta", params

        # LRP sequential composites; variant a is checked before variant b.
        for composite in ('sequential_composite_a', 'sequential_composite_b'):
            if composite in method:
                base = method.split(f'_{composite}')[0]
                return f"{base}_{composite}", params

        # Default: the first component is the base, unless it starts a known
        # multi-part base name.
        parts = method.split('_')
        base_method = parts[0]
        if len(parts) < 2:
            return base_method, params

        head, second = parts[0], parts[1]
        if head in self._LRP_PREFIXES:
            if second in self._LRP_RULES:
                if second == 'w' and len(parts) > 2 and parts[2] == 'square':
                    base_method = f"{head}_w_square"
                elif second == 'sequential':
                    # Composite names were resolved above; keep the bare prefix.
                    pass
                elif second in ('epsilon', 'gamma') and len(parts) > 2:
                    base_method = f"{head}_{second}"
                    value = self._leading_number(parts[2:])
                    if value is not None:
                        params[second] = value
                else:
                    base_method = f"{head}_{second}"
        elif head == 'integrated' and second == 'gradients':
            base_method = 'integrated_gradients'
        elif head == 'guided' and second == 'backprop':
            base_method = 'guided_backprop'
        elif head == 'grad' and second == 'cam':
            base_method = 'grad_cam'
        elif head == 'deep' and second in ('lift', 'taylor'):
            base_method = f"{head}_{second}"
        elif head == 'input' and second == 't' and len(parts) > 2 and parts[2] == 'gradient':
            base_method = 'input_t_gradient'

        return base_method, params

    def _extract_modifiers(self, method: str) -> List[str]:
        """
        Extract modifiers from the method string.

        Returns a list in the fixed order x_input, x_sign, std_x, ignore_bias,
        timeseries, with each modifier included at most once.
        """
        modifiers = []
        if '_x_input' in method or '_times_input' in method:
            modifiers.append('x_input')
        if '_x_sign' in method:
            modifiers.append('x_sign')
        if '_std_x' in method:
            modifiers.append('std_x')
        # "ib" (ignore bias) must be a whole, non-leading token. A plain
        # substring test on "_ib" would also fire on names like "gradient_ibp".
        if 'ib' in method.lower().split('_')[1:]:
            modifiers.append('ignore_bias')
        if '_timeseries' in method:
            modifiers.append('timeseries')
        return modifiers

    def _extract_parameters(self, method: str) -> Dict[str, Any]:
        """
        Extract numerical parameters encoded anywhere in the method string.

        Recognises epsilon_*, mu_[neg_]*, gamma_*, steps_*, noise_* (stored
        as noise_level), samples_* / num_samples_* (stored as num_samples)
        and stdfactor_*. epsilon is skipped when the name contains "alpha".
        """
        params: Dict[str, Any] = {}

        epsilon_match = self._EPSILON_RE.search(method)
        if epsilon_match and 'alpha' not in method:
            params['epsilon'] = self._to_float(*epsilon_match.groups())

        mu_match = self._MU_RE.search(method)
        if mu_match:
            negative, whole, decimal = mu_match.groups()
            value = self._to_float(whole, decimal)
            params['mu'] = -value if negative else value

        gamma_match = self._GAMMA_RE.search(method)
        if gamma_match:
            params['gamma'] = self._to_float(*gamma_match.groups())

        steps_match = self._STEPS_RE.search(method)
        if steps_match:
            params['steps'] = int(steps_match.group(1))

        noise_match = self._NOISE_RE.search(method)
        if noise_match:
            params['noise_level'] = self._to_float(*noise_match.groups())

        samples_match = self._SAMPLES_RE.search(method)
        if samples_match:
            params['num_samples'] = int(samples_match.group(1))

        stdfactor_match = self._STDFACTOR_RE.search(method)
        if stdfactor_match:
            params['stdfactor'] = self._to_float(*stdfactor_match.groups())

        return params

    def _is_param(self, part: str) -> bool:
        """Return True if ``part`` is a recognised parameter-name token."""
        return part in self._PARAM_NAMES