# Source code for signxai

# signxai/__init__.py
__version__ = "0.14.1"

# Backend bookkeeping, filled in by the lazy loaders and import-time probes.
_DEFAULT_BACKEND = None  # later set to "pytorch" or "tensorflow" when one loads
_AVAILABLE_BACKENDS = []  # backend names appended by the lazy loaders

# Handles to the framework-specific subpackages; populated lazily.
tf_signxai = torch_signxai = None


# Lazy loading functions to avoid circular imports
def _load_tf_signxai():
    """Return the ``signxai.tf_signxai`` subpackage, importing it on first use.

    Once imported, the module is cached in the global ``tf_signxai`` handle and
    ``"tensorflow"`` is recorded in ``_AVAILABLE_BACKENDS`` (at most once). If
    TensorFlow or the subpackage raises ImportError, the still-unset handle is
    returned and a later call will try again.
    """
    global tf_signxai
    if tf_signxai is not None:
        return tf_signxai

    try:
        import tensorflow  # noqa: F401  (availability probe only)
        import signxai.tf_signxai as loaded
    except ImportError:
        return tf_signxai

    tf_signxai = loaded
    if "tensorflow" not in _AVAILABLE_BACKENDS:
        _AVAILABLE_BACKENDS.append("tensorflow")
    return tf_signxai


def _load_torch_signxai():
    """Return the ``signxai.torch_signxai`` subpackage, importing it on first use.

    Both ``torch`` and ``zennit`` (needed by the PyTorch LRP methods) must be
    importable. Once imported, the module is cached in the global
    ``torch_signxai`` handle and ``"pytorch"`` is recorded in
    ``_AVAILABLE_BACKENDS`` (at most once). On ImportError the still-unset
    handle is returned and a later call will try again.
    """
    global torch_signxai
    if torch_signxai is not None:
        return torch_signxai

    try:
        import torch  # noqa: F401  (availability probe only)
        import zennit  # noqa: F401  (required for PyTorch LRP methods)
        import signxai.torch_signxai as loaded
    except ImportError:
        return torch_signxai

    torch_signxai = loaded
    if "pytorch" not in _AVAILABLE_BACKENDS:
        _AVAILABLE_BACKENDS.append("pytorch")
    return torch_signxai


# Attempt immediate loading to populate _AVAILABLE_BACKENDS.
# PyTorch is probed first so it becomes the default when both are available.
try:
    import torch
    import zennit
except ImportError:
    pass
else:
    # Claim PyTorch as the default only if the SignXAI PyTorch subpackage really
    # loaded. torch + zennit being importable is not enough: otherwise the
    # default could name a backend missing from _AVAILABLE_BACKENDS and keep a
    # working TensorFlow from becoming the default.
    if _load_torch_signxai() is not None and not _DEFAULT_BACKEND:
        _DEFAULT_BACKEND = "pytorch"

try:
    import tensorflow
except ImportError:
    pass
else:
    # Same rule as above: only a successfully loaded subpackage may be default.
    if _load_tf_signxai() is not None and not _DEFAULT_BACKEND:
        _DEFAULT_BACKEND = "tensorflow"


# Helper functions for API (defined here to avoid circular imports)
def _detect_framework(model):
    """Return ``'pytorch'``, ``'tensorflow'`` or ``None`` for *model*.

    PyTorch is checked first because ``isinstance(model, torch.nn.Module)`` is
    unambiguous. The TensorFlow check also accepts, by duck typing, any object
    with a ``predict`` attribute. Run first, it would misclassify a
    ``torch.nn.Module`` subclass that happens to define ``predict``.
    Frameworks that are not installed are skipped.
    """
    try:
        import torch
    except ImportError:
        pass
    else:
        if isinstance(model, torch.nn.Module):
            return 'pytorch'

    try:
        import tensorflow as tf
    except ImportError:
        pass
    else:
        if isinstance(model, (tf.keras.Model, tf.keras.Sequential)) or hasattr(model, 'predict'):
            return 'tensorflow'

    return None


def _prepare_model(model, framework):
    """Prepare *model* for explanation by removing its softmax output.

    Args:
        model: A Keras model (``framework == 'tensorflow'``) or a
            ``torch.nn.Module`` (any other value is treated as PyTorch).
        framework: ``'tensorflow'`` or ``'pytorch'``.

    Returns:
        The model returned by the framework's ``remove_softmax`` helper.
    """
    if framework == 'tensorflow':
        # NOTE(review): the caller's model is passed straight to remove_softmax;
        # whether that helper copies or mutates it is not visible here.
        from signxai.tf_signxai.tf_utils import remove_softmax
        return remove_softmax(model)

    # pytorch
    import copy

    from signxai.torch_signxai.torch_utils import remove_softmax

    # Work on a deep copy so the caller's module is left untouched. Rebuilding
    # via model.__class__(**public __dict__ entries) would pass e.g. `training=`
    # to the constructor, which nn.Module subclasses generally reject, and it
    # cannot recover the original constructor arguments anyway.
    model_copy = copy.deepcopy(model)
    return remove_softmax(model_copy)


def _prepare_input(x, framework):
    """Convert *x* to the native input type of *framework*.

    For ``'tensorflow'`` a ``numpy.ndarray`` is returned: tensors exposing
    ``detach`` are moved to CPU and converted, and other non-arrays go through
    ``np.array``. For any other framework a ``torch.Tensor`` is returned: NumPy
    arrays are cast to float32, other non-tensors become float32 tensors, and
    existing tensors pass through unchanged.
    """
    import numpy as np

    if framework == 'tensorflow':
        if hasattr(x, 'detach'):  # PyTorch tensor
            return x.detach().cpu().numpy()
        if not isinstance(x, np.ndarray):
            return np.array(x)
        return x

    # pytorch
    import torch

    if isinstance(x, np.ndarray):
        # torch.from_numpy rejects arrays with negative strides (e.g. views
        # created by x[..., ::-1] for channel flips), so copy those first.
        if any(stride < 0 for stride in x.strides):
            x = np.ascontiguousarray(x)
        return torch.from_numpy(x).float()
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, dtype=torch.float32)
    return x


def _get_predicted_class(model, x, framework):
    """Return the index of the highest-scoring class for the first sample of *x*.

    For ``'tensorflow'`` the model's ``predict`` output is used. For any other
    framework the model is called directly under ``torch.no_grad()``. Note that
    ``model.eval()`` is called and the model is left in eval mode.
    """
    import numpy as np

    if framework == 'tensorflow':
        preds = model.predict(x, verbose=0)
        return int(np.argmax(preds[0]))

    # pytorch
    import torch

    model.eval()
    with torch.no_grad():
        preds = model(x)
    # Mirror the TensorFlow branch and score only the first sample. The old
    # `argmax(dim=1).item()` raised for any batch size > 1 and for unbatched
    # 1-D outputs.
    scores = preds[0] if preds.dim() > 1 else preds
    return int(torch.argmax(scores, dim=0).item())


def _map_parameters(method_name, framework, **kwargs):
    """Rename keyword arguments to the spelling expected by *framework*.

    Each per-framework table maps ``target_name -> alias``. When the caller
    passed ``alias``, its value is moved to ``target_name``. Methods or
    frameworks without a table pass through unchanged. A new dict is returned
    and *kwargs* itself is not modified.
    """
    renames_by_method = {
        'integrated_gradients': {
            'tensorflow': {'reference_inputs': 'baseline', 'steps': 'steps'},
            'pytorch': {'baseline': 'reference_inputs', 'ig_steps': 'steps'},
        },
        'smoothgrad': {
            'tensorflow': {'augment_by_n': 'num_samples', 'noise_scale': 'noise_level'},
            'pytorch': {'num_samples': 'augment_by_n', 'noise_level': 'noise_scale'},
        },
        'gradcam': {
            'tensorflow': {'layer_name': 'layer_name'},
            'pytorch': {'target_layer': 'layer_name'},
        },
    }

    result = dict(kwargs)
    renames = renames_by_method.get(method_name, {}).get(framework, {})
    for target, alias in renames.items():
        # Membership is tested against the caller's original kwargs.
        if alias in kwargs:
            result[target] = result.pop(alias)
    return result


def _call_tensorflow_method(model, x, method_name, target_class, **kwargs):
    """Run a SignXAI method on a TensorFlow model and return its result.

    The method-family registry is tried first. If it raises, the method name is
    parsed with ``MethodParser`` and passed to the direct
    ``signxai.tf_signxai.methods.execute`` implementation instead.
    ``target_class`` is forwarded both as ``target_class`` and as
    ``neuron_selection``.

    Raises:
        ImportError: if the TensorFlow backend of SignXAI cannot be loaded.
    """
    if _load_tf_signxai() is None:
        # Raises the detailed installation guide when no backend is installed.
        _check_framework_availability()
        # That check passes when only PyTorch is installed, so stop here rather
        # than continuing into TensorFlow-only code paths.
        raise ImportError(
            "The TensorFlow backend of SignXAI is not available. "
            "Install it with: pip install signxai2[tensorflow]"
        )

    # Use the method family architecture
    from signxai.common.method_families import get_registry
    from signxai.common.method_parser import MethodParser

    try:
        registry = get_registry()
        return registry.execute(
            model=model,
            x=x,
            method_name=method_name,
            framework='tensorflow',
            target_class=target_class,
            neuron_selection=target_class,
            **kwargs
        )
    except Exception:
        # Deliberate best-effort fallback: any registry failure is retried via
        # the direct implementation. Doing it inside the handler keeps the
        # registry error attached as __context__ if the fallback fails too.
        from signxai.tf_signxai.methods import execute as tf_execute
        parsed_method = MethodParser().parse(method_name)
        return tf_execute(
            model=model,
            x=x,
            parsed_method=parsed_method,
            target_class=target_class,
            neuron_selection=target_class,
            **kwargs
        )


def _call_pytorch_method(model, x, method_name, target_class, **kwargs):
    """Run a SignXAI method on a PyTorch model and return its result.

    The method-family registry is tried first. If it raises, the method name is
    parsed with ``MethodParser`` and passed to the direct
    ``signxai.torch_signxai.methods.execute`` implementation instead.

    Raises:
        ImportError: if the PyTorch backend of SignXAI cannot be loaded.
    """
    if _load_torch_signxai() is None:
        # Raises the detailed installation guide when no backend is installed.
        _check_framework_availability()
        # That check passes when only TensorFlow is installed, so stop here
        # rather than continuing into PyTorch-only code paths.
        raise ImportError(
            "The PyTorch backend of SignXAI is not available. "
            "Install it with: pip install signxai2[pytorch]"
        )

    # Use the method family architecture
    from signxai.common.method_families import get_registry
    from signxai.common.method_parser import MethodParser

    try:
        registry = get_registry()
        return registry.execute(
            model=model,
            x=x,
            method_name=method_name,
            framework='pytorch',
            target_class=target_class,
            **kwargs
        )
    except Exception:
        # Deliberate best-effort fallback: any registry failure is retried via
        # the direct implementation. Doing it inside the handler keeps the
        # registry error attached as __context__ if the fallback fails too.
        from signxai.torch_signxai.methods import execute as pt_execute
        parsed_method = MethodParser().parse(method_name)
        return pt_execute(
            model=model,
            x=x,
            parsed_method=parsed_method,
            target_class=target_class,
            **kwargs
        )


# Legacy framework-specific imports (for backwards compatibility)
def _framework_specific_import_required(*args, **kwargs):
    """Stand-in for the top-level ``calculate_relevancemap(s)`` names.

    Ignores all arguments and always raises ImportError pointing users to the
    unified ``explain`` API or the framework-specific subpackages.
    """
    raise ImportError(
        "Use the unified API: from signxai import explain\n"
        "Or framework-specific imports:\n"
        "  TensorFlow: from signxai.tf_signxai import calculate_relevancemap\n"
        "  PyTorch: from signxai.torch_signxai import calculate_relevancemap"
    )


calculate_relevancemap = calculate_relevancemaps = _framework_specific_import_required


# Check if any framework is available
def _check_framework_availability():
    """Raise ImportError with installation guidance if no backend is available.

    Returns ``None`` without doing anything when ``_AVAILABLE_BACKENDS`` is
    non-empty.
    """
    if _AVAILABLE_BACKENDS:
        return

    # Keep the horizontal rule in its own variable. Adjacent string literals are
    # joined before `*` applies, so writing `"...\n" "=" * 70` would repeat the
    # entire message 70 times instead of drawing one rule.
    rule = "=" * 70
    body = (
        "ERROR: No deep learning framework detected!\n\n"
        "SignXAI2 requires at least one framework to be installed.\n"
        "You have installed signxai2 without specifying a framework.\n\n"
        "Please install SignXAI2 with one of the following options:\n\n"
        "  For TensorFlow support:\n"
        "    pip install signxai2[tensorflow]\n\n"
        "  For PyTorch support:\n"
        "    pip install signxai2[pytorch]\n\n"
        "  For both frameworks:\n"
        "    pip install signxai2[all]\n\n"
        "  For development (includes all frameworks + dev tools):\n"
        "    pip install signxai2[dev]\n\n"
        "Note: Python 3.9 or 3.10 is required.\n"
    )
    raise ImportError(f"\n{rule}\n{body}{rule}\n")


# Import API functions for convenience
try:
    from .api import (
        explain as _explain_impl,
        list_methods as _list_methods_impl,
        get_method_info as _get_method_info_impl,
        explain_with_preset as _explain_with_preset_impl,
        METHOD_PRESETS,
    )
except ImportError as e:
    # Use the warnings machinery rather than print(). Library code should not
    # write to stdout at import time, and warnings can be filtered or turned
    # into errors by the application.
    import warnings

    warnings.warn(f"Could not import unified API: {e}")
    _API_AVAILABLE = False
else:
    _API_AVAILABLE = True


# Create wrapper functions that check framework availability
def _ensure_api_ready():
    """Raise ImportError unless a backend and the unified API are both available."""
    if not _AVAILABLE_BACKENDS:
        _check_framework_availability()
    if not _API_AVAILABLE:
        raise ImportError("SignXAI API is not available. Please check your installation.")


def explain(*args, **kwargs):
    """Wrapper for explain that checks framework availability."""
    _ensure_api_ready()
    return _explain_impl(*args, **kwargs)


def list_methods(*args, **kwargs):
    """Wrapper for list_methods that checks framework availability."""
    _ensure_api_ready()
    return _list_methods_impl(*args, **kwargs)


def get_method_info(*args, **kwargs):
    """Wrapper for get_method_info that checks framework availability."""
    _ensure_api_ready()
    return _get_method_info_impl(*args, **kwargs)


def explain_with_preset(*args, **kwargs):
    """Wrapper for explain_with_preset that checks framework availability."""
    _ensure_api_ready()
    return _explain_with_preset_impl(*args, **kwargs)
# Dynamically build __all__
__all__ = [
    '__version__',
    '_DEFAULT_BACKEND',
    '_AVAILABLE_BACKENDS',
    'calculate_relevancemap',
    'calculate_relevancemaps',
]

# Add API functions if available
if _API_AVAILABLE:
    __all__.extend(['explain', 'list_methods', 'get_method_info', 'explain_with_preset', 'METHOD_PRESETS'])

# Add modules to __all__ if available (the loaders return None when a backend
# cannot be imported).
if _load_tf_signxai():
    __all__.append('tf_signxai')
if _load_torch_signxai():
    __all__.append('torch_signxai')

# Note: Framework availability is checked when API functions are called