# Source code for signxai.torch_signxai.methods_family

"""
PyTorch integration for Method Family Architecture.
This module provides an alternative entry point that uses the new family-based approach.
"""

import torch
import numpy as np
import logging
import os

# Module-level logger, named after this module so handlers and levels can be
# configured per package through the standard logging hierarchy.
logger = logging.getLogger(__name__)

def calculate_relevancemap_with_families(model, input_tensor, method: str, target_class=None, **kwargs):
    """
    Calculate relevance map using the Method Family Architecture.

    The family registry is tried first unless the environment variable
    ``SIGNXAI_USE_METHOD_FAMILIES`` is set to a value other than ``"true"``
    (compared case-insensitively; an unset variable means enabled). If the
    registry path raises any exception, the failure is logged and the call
    falls back to the zennit-based implementation.

    Args:
        model: PyTorch model (without softmax)
        input_tensor: Input tensor
        method: The XAI method to use
        target_class: Target class index (optional)
        **kwargs: Additional method-specific parameters

    Returns:
        Relevance map as numpy array (the value is whatever the selected
        backend returns).
    """
    # Families are the default; only an explicit non-"true" value disables them.
    use_families = os.environ.get('SIGNXAI_USE_METHOD_FAMILIES', 'true').lower() == 'true'

    if use_families:
        try:
            # Imported lazily so a missing/broken family package only triggers
            # the fallback instead of breaking import of this module.
            from ..common.method_families import get_registry
            registry = get_registry()

            # `target_class` is a named parameter and therefore can never be
            # inside **kwargs; build a separate dict for the registry call so
            # the caller-supplied kwargs stay untouched for the fallback.
            family_kwargs = dict(kwargs)
            if target_class is not None:
                family_kwargs['target_class'] = target_class

            logger.info("Attempting to execute %s with Method Family Architecture", method)
            result = registry.execute(
                model=model,
                x=input_tensor,
                method_name=method,
                framework='pytorch',
                **family_kwargs
            )
            logger.info("Successfully executed %s with Method Family Architecture", method)
            return result

        except Exception as e:
            # Deliberate best-effort: any failure in the family path must not
            # prevent the original implementation from being tried.
            logger.warning("Method Family execution failed for %s: %s", method, e)
            # Keep the traceback available for diagnosis without making the
            # default (warning-level) output noisier.
            logger.debug("Method Family failure traceback for %s", method, exc_info=True)
            logger.info("Falling back to original implementation")

    # Fallback to zennit_impl
    from .methods.zennit_impl import calculate_relevancemap as zennit_calculate

    # kwargs never contains 'target_class' (see above), so it can be forwarded
    # as-is alongside the explicit target_class argument.
    return zennit_calculate(
        model=model,
        input_tensor=input_tensor,
        method=method,
        target_class=target_class,
        **kwargs
    )