Source code for signxai.utils.utils

"""
Framework-agnostic utility functions for SignXAI2.

This module contains utility functions that work independently of any deep learning framework.
Framework-specific utilities have been moved to:
- signxai/tf_signxai/tf_utils.py for TensorFlow utilities
- signxai/torch_signxai/torch_utils.py for PyTorch utilities
"""

import os
import sys
import numpy as np
import requests
from PIL import Image


[docs] def get_examples_data_dir(): """ Get the path to the examples/data directory regardless of current working directory Returns: str: Path to the examples/data directory """ # Try to find examples/data directory current_dir = os.getcwd() # Check if we're in examples/tutorials/tensorflow if os.path.basename(current_dir) == 'tensorflow' and \ os.path.basename(os.path.dirname(current_dir)) == 'tutorials': return os.path.join(os.path.dirname(os.path.dirname(current_dir)), 'data') # Check if we're in examples/tutorials if os.path.basename(current_dir) == 'tutorials' and \ os.path.basename(os.path.dirname(current_dir)) == 'examples': return os.path.join(os.path.dirname(current_dir), 'data') # Check if we're in examples if os.path.basename(current_dir) == 'examples': return os.path.join(current_dir, 'data') # Check if we're in project root if os.path.exists(os.path.join(current_dir, 'examples', 'data')): return os.path.join(current_dir, 'examples', 'data') # Last resort: try to find examples/data relative to script location script_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) examples_data_dir = os.path.join(script_dir, 'examples', 'data') if os.path.exists(examples_data_dir): return examples_data_dir # If we couldn't find it, create it in the current directory os.makedirs(os.path.join(current_dir, 'data'), exist_ok=True) return os.path.join(current_dir, 'data')
[docs] def aggregate_and_normalize_relevancemap_rgb(relevancemap): """ Aggregate and normalize a RGB relevance map Args: relevancemap: RGB relevance map Returns: Normalized relevance map """ # Aggregate channels if relevancemap.ndim == 3: relevancemap = np.sum(relevancemap, axis=2) return normalize_heatmap(relevancemap)
[docs] def normalize_heatmap(heatmap): """ Normalize a heatmap to the range [-1, 1] Args: heatmap: Heatmap to normalize Returns: Normalized heatmap """ if heatmap.min() != heatmap.max(): max_abs = np.max(np.abs(heatmap)) if max_abs > 0: heatmap = heatmap / max_abs return np.nan_to_num(heatmap, nan=0) else: return np.zeros_like(heatmap)
# Deprecated functions with framework-specific imports # These are kept for backward compatibility but will delegate to framework-specific modules
[docs] def remove_softmax(model): """ Remove the softmax activation from the last layer of a model. This function delegates to the appropriate framework-specific implementation. Args: model: TensorFlow or PyTorch model Returns: Model with softmax removed """ # Try to determine the framework model_type = type(model).__module__ if 'tensorflow' in model_type or 'keras' in model_type: from ..tf_signxai.tf_utils import remove_softmax as tf_remove_softmax return tf_remove_softmax(model) elif 'torch' in model_type: from ..torch_signxai.torch_utils import remove_softmax as torch_remove_softmax return torch_remove_softmax(model) else: raise ValueError(f"Unable to determine framework for model type: {type(model)}")
[docs] def calculate_explanation_innvestigate(model, x, method='lrp.epsilon', neuron_selection=None, batchmode=False, **kwargs): """ Calculate an explanation using the innvestigate backend (TensorFlow only). This function has been moved to tf_utils.py as it's TensorFlow-specific. This wrapper is kept for backward compatibility. """ from ..tf_signxai.tf_utils import calculate_explanation_innvestigate as tf_calc_explanation return tf_calc_explanation(model, x, method, neuron_selection, batchmode, **kwargs)
[docs] def load_image(img_path, target_size=(224, 224), expand_dims=False, use_original_preprocessing=True): """ Load an image from a file path and preprocess it. This function has been moved to tf_utils.py as it uses TensorFlow preprocessing. This wrapper is kept for backward compatibility. """ from ..tf_signxai.tf_utils import load_image as tf_load_image return tf_load_image(img_path, target_size, expand_dims, use_original_preprocessing)
[docs] def download_image(path): """ Download example image if it doesn't exist Args: path: Path to save the image """ if not os.path.exists(path): # Create directory if it doesn't exist directory = os.path.dirname(path) if not os.path.exists(directory): os.makedirs(directory) # Example image URL url = "https://raw.githubusercontent.com/nilsgumpfer/SIGN-experiment-resources/main/example.jpg" response = requests.get(url) with open(path, 'wb') as f: f.write(response.content)
[docs] def download_model(path): """ Download example model if it doesn't exist Args: path: Path to save the model """ if not os.path.exists(path): # Create directory if it doesn't exist directory = os.path.dirname(path) if not os.path.exists(directory): os.makedirs(directory) # Example model URL url = "https://raw.githubusercontent.com/nilsgumpfer/SIGN-experiment-resources/main/model.h5" response = requests.get(url) with open(path, 'wb') as f: f.write(response.content)