# Source code for signxai.torch_signxai

# signxai/torch_signxai/__init__.py

import os
import logging

logger = logging.getLogger(__name__)

# The relevance-map backend is chosen once, when this package is imported.
# SIGNXAI_USE_METHOD_FAMILIES defaults to "true"; any value whose lowercase
# form is not exactly "true" selects the original implementation instead.
USE_METHOD_FAMILIES = (
    os.environ.get('SIGNXAI_USE_METHOD_FAMILIES', 'true').lower() == 'true'
)

if not USE_METHOD_FAMILIES:
    # Opt-out path: the original implementation from ``methods.zennit_impl``.
    logger.info("Using original wrapper implementation for PyTorch")
    from .methods.zennit_impl import calculate_relevancemap as calculate_relevancemap
else:
    # Default path: expose the Method Family entry point under the public name.
    logger.info("Using Method Family Architecture for PyTorch")
    from .methods_family import calculate_relevancemap_with_families as calculate_relevancemap

# Import utilities that are commonly used
from .torch_utils import remove_softmax, decode_predictions, NoSoftmaxWrapper

# Import individual method functions for compatibility
try:
    from .methods.wrappers import (
        integrated_gradients,
        grad_cam,
    )
except ImportError as exc:
    # The legacy wrappers module is optional. When it cannot be imported, fall
    # back to thin stubs that route through ``calculate_relevancemap`` (the
    # backend selected above). Record the reason at debug level so that an
    # ImportError raised *inside* the wrappers module is not hidden entirely.
    logger.debug(
        "Legacy PyTorch wrappers unavailable (%s); using calculate_relevancemap stubs",
        exc,
    )

    def integrated_gradients(*args, **kwargs):
        """Compute a relevance map with the ``'integrated_gradients'`` method.

        All positional and keyword arguments are forwarded unchanged to
        ``calculate_relevancemap`` together with
        ``method='integrated_gradients'``. Supplying ``method`` in ``kwargs``
        raises ``TypeError`` (duplicate keyword argument).
        """
        return calculate_relevancemap(*args, method='integrated_gradients', **kwargs)

    def grad_cam(*args, **kwargs):
        """Compute a relevance map with the ``'grad_cam'`` method.

        All positional and keyword arguments are forwarded unchanged to
        ``calculate_relevancemap`` together with ``method='grad_cam'``.
        Supplying ``method`` in ``kwargs`` raises ``TypeError`` (duplicate
        keyword argument).
        """
        return calculate_relevancemap(*args, method='grad_cam', **kwargs)
# Define what gets imported with "from signxai.torch_signxai import *".
__all__ = [
    # Backend picked at import time via SIGNXAI_USE_METHOD_FAMILIES: the
    # Method Family implementation by default, otherwise the one from
    # ``methods.zennit_impl``.
    "calculate_relevancemap",
    "remove_softmax",
    "decode_predictions",
    "NoSoftmaxWrapper",
    # Individual methods kept for API compatibility
    "integrated_gradients",
    "grad_cam",
]