import warnings
import tensorflow as tf
###############################################################################
###############################################################################
###############################################################################
import tensorflow.keras.backend as K
import numpy as np
from ..utils.keras import graph as kgraph
import tensorflow.keras.layers as keras_layers
class ReplacementLayer():
"""
Base class for providing explainability functionality.
This class wraps a single network layer, providing hooks for applying the layer and retrieving an explanation.
Basically:
* Any forward passes required for computing the explanation are defined in the try_apply method. During the forward pass, a callback is given to all child layers to retrieve their explanations
* Wrappers (e.g., a GradientTape) around the forward pass(es) that are required to compute an explanation can be defined and returned by the wrap_hook method
* In the wrap_hook method, the forward pass(es) are applied and any Tapes are defined
* The explanation is computed in the explain_hook method and then passed to the callback functions of the parent ReplacementLayers
:param layer: Layer (of base class tensorflow.keras.layers.Layer) to wrap
:param layer_next: List of layers in the network that receive the output of layer (= child layers)
:param r_init_constant: If not None, defines a constant output (reverse/head) mapping, i.e., the value the explanation of this layer is initialized with
:param f_init_constant: If not None, defines a constant activation (forward) mapping, i.e., the value this layer's forward input is replaced with
This is just a base class. To extend this for specific XAI methods,
- wrap_hook()
- explain_hook()
should be overridden accordingly
"""
def __init__(self, layer, layer_next=None, r_init_constant=None, f_init_constant=None, **kwargs):
self.layer_func = layer
self.layer_next = layer_next if layer_next is not None else []
self.name = layer.name
self.r_init = r_init_constant
self.f_init = f_init_constant
# Handle input_shape compatibility for modern TensorFlow/Keras
if hasattr(layer, 'input_shape'):
self.input_shape = layer.input_shape
elif hasattr(layer, '_inbound_nodes') and layer._inbound_nodes:
# Modern Keras - get input shapes from inbound nodes
try:
input_shapes = [tensor.shape for tensor in layer._inbound_nodes[0].input_tensors]
self.input_shape = input_shapes[0] if len(input_shapes) == 1 else input_shapes
except (IndexError, AttributeError):
# Fallback to input_spec if available
if hasattr(layer, 'input_spec') and layer.input_spec:
if hasattr(layer.input_spec, 'shape') and layer.input_spec.shape:
self.input_shape = layer.input_spec.shape
else:
# Fallback shape
self.input_shape = (None,)
else:
self.input_shape = (None,)
else:
self.input_shape = (None,)
if not isinstance(self.input_shape, list):
self.input_shape = [self.input_shape]
# Handle output_shape compatibility for modern TensorFlow/Keras
if hasattr(layer, 'output_shape'):
self.output_shape = layer.output_shape
elif hasattr(layer, '_inbound_nodes') and layer._inbound_nodes:
# Modern Keras - get output shapes from inbound nodes
try:
output_shapes = [tensor.shape for tensor in layer._inbound_nodes[0].output_tensors]
self.output_shape = output_shapes[0] if len(output_shapes) == 1 else output_shapes
except (IndexError, AttributeError):
# Fallback shape based on layer type
self.output_shape = (None,)
else:
self.output_shape = (None,)
if not isinstance(self.output_shape, list):
self.output_shape = [self.output_shape]
self.input_vals = None
self.original_output_vals = None
self.reversed_output_vals = None
self.callbacks = None
self.hook_vals = None
self.explanation = None
###############
# TODO: remove one of the variables
self.forward_after_stopping = True
self.base_forward_after_stopping = False
#############
self.reached_after_stop_mapping = None
def try_explain(self, reversed_outs):
"""
callback function called by child layers when their explanation is computed.
* aggregates explanations of all children
* calls explain_hook to compute own explanation
* sends own explanation to all parent layers by calling their callback functions
:param reversed_outs: the child layer's explanation. None if this is the last layer.
"""
# aggregate explanations
if reversed_outs is not None:
if self.reversed_output_vals is None:
self.reversed_output_vals = []
self.reversed_output_vals.append(reversed_outs)
# last layer or aggregation finished
if self.reversed_output_vals is None or len(self.reversed_output_vals) == len(self.layer_next):
# apply post hook: explain
if self.hook_vals is None:
raise ValueError(
"self.hook_vals should contain values at this point. Is self.wrap_hook working correctly?")
input_vals = self.input_vals
if len(input_vals) == 1:
input_vals = input_vals[0]
rev_outs = self.reversed_output_vals
if rev_outs is not None:
if len(rev_outs) == 1:
rev_outs = rev_outs[0]
# print("Backward:", self.name)
# print(self.name, len(self.layer_next))
# print(self.name, np.shape(input_vals), np.shape(rev_outs), np.shape(self.hook_vals[0]))
self.explanation = self.explain_hook(input_vals, rev_outs, self.hook_vals)
# callbacks
if self.callbacks is not None:
# check if multiple inputs explained
if len(self.callbacks) > 1 and not isinstance(self.explanation, list):
raise ValueError(
f"{self.name}: This layer has {len(self.callbacks)} inputs, but no list of explanations was provided.")
elif len(self.callbacks) > 1 and len(self.callbacks) != len(self.explanation):
raise ValueError(
f"{self.name}: This layer has {len(self.callbacks)} inputs, but only {len(self.explanation)} explanations were computed.")
if len(self.callbacks) > 1:
for c, callback in enumerate(self.callbacks):
callback(self.explanation[c])
else:
self.callbacks[0](self.explanation)
# reset
self.input_vals = None
self.reversed_output_vals = None
self.callbacks = None
self.hook_vals = None
def _forward(self, Ys, neuron_selection=None, stop_mapping_at_layers=None, r_init=None, f_init=None):
"""
Forward Pass to all child layers
* If this is the last layer, directly calls try_explain to compute own explanation
* Otherwise calls try_apply on all child layers
:param Ys: output of own forward pass
:param neuron_selection: neuron_selection parameter (see try_apply)
:param stop_mapping_at_layers: stop_mapping_at_layers parameter (see try_apply)
:param r_init: None or Scalar or Array-Like or Dict {layer_name:scalar or array-like} reverse initialization value. Value with which the explanation is initialized (i.e., head_mapping).
:param f_init: None or Scalar or Array-Like or Dict {layer_name:scalar or array-like} forward initialization value. Value with which the forward is initialized.
"""
#print("Forward: ", self.name)
if len(self.layer_next) == 0:
# last layer: directly compute explanation
self.try_explain(None)
elif stop_mapping_at_layers is not None and self.name in stop_mapping_at_layers:
self.try_explain(None)
#########################
# TODO: New Code, remove comments if it is good
#if self.base_forward_after_stopping:
# for layer_n in self.layer_next:
# layer_n.try_apply(Ys, None, neuron_selection, stop_mapping_at_layers, r_init)
#########################
# TODO: Leander checks if this code is still necessary
if self.forward_after_stopping:
#if the output was mapped, this restores the original for the forwarding
if self.original_output_vals is not None:
Ys = self.original_output_vals
self.original_output_vals = None
# use a dummy callback so that the child layers' callback bookkeeping stays consistent
def dummyCallback(reversed_outs):
pass
for layer_n in self.layer_next:
layer_n.try_apply(Ys, dummyCallback, neuron_selection, stop_mapping_at_layers, r_init, f_init)
#############################
else:
# forward
for layer_n in self.layer_next:
layer_n.try_apply(Ys, self.try_explain, neuron_selection, stop_mapping_at_layers, r_init, f_init)
@tf.custom_gradient
def _toNumber(self, X, value):
"""
Helper function that sets a Tensor to a fixed value while backpropagating a gradient scaled by "value"
"""
y = tf.constant(value, dtype=tf.float32, shape=X.shape)
def grad(dy, variables=None):  # variables=None and the extra None output are needed because _toNumber takes two arguments (X and value)
return dy * tf.ones(X.shape) * value, None
return y, grad
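# Illustration (not executed): for X of shape (2, 3) and value=1.0, the forward output is a
# constant tensor of ones with shape (2, 3); the backward pass returns dy * value as the
# gradient w.r.t. X and None as the gradient w.r.t. value.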
def _neuron_select(self, Ys, neuron_selection):
"""
Performs neuron_selection on Ys
:param Ys: output of own forward pass
:param neuron_selection: neuron_selection parameter (see try_apply)
"""
# error handling is done before, in try_apply
if neuron_selection is None:
# keep all neurons
pass
elif isinstance(neuron_selection, tf.Tensor):
# flatten the non-batch axes, then gather the selected neuron index per sample
Ys = tf.reshape(Ys, (Ys.shape[0], np.prod(Ys.shape[1:])))
Ys = tf.gather_nd(Ys, neuron_selection, batch_dims=1)
else:
# "max_activation": reduce to the maximally activated neuron per sample
Ys = K.max(Ys, axis=1, keepdims=True)
return Ys
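# Illustration (not executed): for Ys of shape (2, 10) and neuron_selection = tf.constant([[3], [7]]),
# the reshape keeps shape (2, 10) and tf.gather_nd(..., batch_dims=1) picks neuron 3 of sample 0
# and neuron 7 of sample 1, yielding a tensor of shape (2,).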
def _head_mapping(self, Ys, model_output_value=None):
"""
Sets the (selected) model output to a fixed value. Used as initialization
for the explanation method (head mapping).
:param Ys: output of own forward pass (after neuron selection)
:param model_output_value: output value of the model / initialization value for the explanation method
"""
if model_output_value is not None:
if isinstance(model_output_value, dict):
if self.name in model_output_value.keys():
# model_output_value should be int or array-like. Shape should fit.
Ys = self._toNumber(Ys, model_output_value[self.name])
else:
# model_output_value should be int or array-like. Shape should fit.
Ys = self._toNumber(Ys, model_output_value)
return Ys
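# Illustration (not executed; "dense_2" is a hypothetical layer name): with
# model_output_value = {"dense_2": 1.0}, only the layer named "dense_2" replaces its selected
# output by a constant 1.0; a plain scalar or array-like value is applied whenever this head
# mapping is reached (i.e., at final or stop-mapping layers).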
def _neuron_sel_and_head_map(self, Ys, neuron_selection=None, model_output_value=None):
#save original output
self.original_output_vals = Ys
#apply neuron selection and head mapping
Ys = self._neuron_select(Ys, neuron_selection)
Ys = self._head_mapping(Ys, model_output_value)
return Ys
def try_apply(self, ins, callback=None, neuron_selection=None, stop_mapping_at_layers=None, r_init=None, f_init=None):
"""
Tries to apply own forward pass:
* Aggregates inputs and callbacks of all parent layers
* Performs a canonization of the neuron_selection parameter
* Calls wrap_hook (wrapped forward pass(es))
* Calls _forward (forward result of forward pass to child layers)
:param ins: input(s) of this layer, i.e., the output of the parent layer's forward pass
:param neuron_selection: neuron_selection parameter. One of the following:
- None or "all"
- "max_activation"
- int
- list or np.array of int, with length equal to batch size
:param stop_mapping_at_layers: None or list of layers to stop mapping at
:param callback: callback function of the parent layer that called self.try_apply
:param r_init: None or Scalar or Array-Like or Dict {layer_name:scalar or array-like} reverse initialization value. Value with which the explanation is initialized (i.e., head_mapping).
:param f_init: None or Scalar or Array-Like or Dict {layer_name:scalar or array-like} forward initialization value. Value with which the forward is initialized.
"""
# DEBUG
# print(self.name, self.input_shape, np.shape(ins))
#uses the class attribute, if it is not None.
if self.r_init is not None:
r_init = self.r_init
if self.f_init is not None:
f_init = self.f_init
# aggregate inputs
if self.input_vals is None:
self.input_vals = []
self.input_vals.append(ins)
# aggregate callbacks
if callback is not None:
if self.callbacks is None:
self.callbacks = []
self.callbacks.append(callback)
# reset explanation
self.explanation = None
# apply layer only if all inputs collected. Then reset inputs
if len(self.input_vals) == len(self.input_shape):
# set inputs to f_init, if it is not None
if f_init is not None:
if isinstance(f_init, dict):
if self.name in f_init.keys():
# f_init should be int or array-like. Shape should fit.
for i, in_val in enumerate(self.input_vals):
self.input_vals[i] = self._toNumber(in_val, f_init[self.name])
else:
# f_init should be int or array-like. Shape should fit.
for i, in_val in enumerate(self.input_vals):
self.input_vals[i] = self._toNumber(in_val, f_init)
# tensorify wrap_hook inputs as much as possible for graph efficiency
input_vals = self.input_vals
if len(input_vals) == 1:
input_vals = input_vals[0]
# adapt neuron_selection param
if len(self.layer_next) == 0 or (stop_mapping_at_layers is not None and self.name in stop_mapping_at_layers):
if neuron_selection is None:
neuron_selection_tmp = None
elif isinstance(neuron_selection, str) and neuron_selection == "all":
neuron_selection_tmp = None
elif isinstance(neuron_selection, str) and neuron_selection == "max_activation":
neuron_selection_tmp = "max_activation"
elif isinstance(neuron_selection, int):
neuron_selection_tmp = [[neuron_selection] for _ in range(self.input_vals[0].shape[0])]
neuron_selection_tmp = tf.constant(neuron_selection_tmp)
elif isinstance(neuron_selection, list) or (
hasattr(neuron_selection, "shape") and len(neuron_selection.shape) == 1):
neuron_selection_tmp = [[n] for n in neuron_selection]
neuron_selection_tmp = tf.constant(neuron_selection_tmp)
else:
raise ValueError(
"Parameter neuron_selection only accepts the following values: None, 'all', 'max_activation', <int>, <list>, <one-dimensional array>")
else:
neuron_selection_tmp = neuron_selection
# apply layer and any wrappers (e.g., a GradientTape)
self.hook_vals = self.wrap_hook(input_vals, neuron_selection_tmp, stop_mapping_at_layers, r_init)
# forward
if isinstance(self.hook_vals, tuple):
self._forward(self.hook_vals[0], neuron_selection, stop_mapping_at_layers, r_init, f_init)
else:
self._forward(self.hook_vals, neuron_selection, stop_mapping_at_layers, r_init, f_init)
def wrap_hook(self, ins, neuron_selection, stop_mapping_at_layers, r_init):
"""
hook that wraps and applies the layer function.
E.g., by defining a GradientTape
* should contain a call to self._neuron_select (e.g., via self._neuron_sel_and_head_map)
* may define any wrappers around the forward pass (e.g., a GradientTape)
:param ins: input(s) of this layer
:param neuron_selection: neuron_selection parameter (see try_apply)
:param stop_mapping_at_layers: None or stop_mapping_at_layers parameter (see try_apply)
:param r_init: reverse initialization value. Value with which the explanation is initialized (i.e., head_mapping).
:returns output of layer function + any wrappers that were defined and are needed in explain_hook
To be extended for specific XAI methods
"""
outs = self.layer_func(ins)
# check if final layer (i.e., no next layers)
if len(self.layer_next) == 0 or (stop_mapping_at_layers is not None and self.name in stop_mapping_at_layers):
outs = self._neuron_sel_and_head_map(outs, neuron_selection, r_init)
return outs
def explain_hook(self, ins, reversed_outs, args):
"""
hook that computes the explanations.
* Core XAI functionality
:param ins: input(s) of this layer
:param args: outputs of wrap_hook (any parameters that may be needed to compute explanation)
:param reversed_outs: either backpropagated explanation(s) of child layers, or None if this is the last layer
:returns explanation, or tensor of multiple explanations if the layer has multiple inputs (one for each)
To be extended for specific XAI methods
"""
outs = args
if reversed_outs is None:
reversed_outs = outs
if len(self.layer_next) > 1:
#TODO is this addition correct?
ret = keras_layers.Add(dtype=tf.float32)(list(reversed_outs))
elif len(self.input_shape) > 1:
ret = [reversed_outs for i in self.input_shape]
ret = tf.keras.layers.concatenate(ret, axis=1)
else:
ret = reversed_outs
return ret
class GradientReplacementLayer(ReplacementLayer):
"""
Simple extension of ReplacementLayer
* Explains by computing gradients of outputs w.r.t. inputs of layer
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def wrap_hook(self, ins, neuron_selection, stop_mapping_at_layers, r_init):
with tf.GradientTape(persistent=True) as tape:
tape.watch(ins)
outs = self.layer_func(ins)
# check if final layer (i.e., no next layers)
if len(self.layer_next) == 0 or (stop_mapping_at_layers is not None and self.name in stop_mapping_at_layers):
outs = self._neuron_sel_and_head_map(outs, neuron_selection=neuron_selection, model_output_value=r_init)
# print('GradientReplacementLayer init:', outs, r_init)
return outs, tape
def explain_hook(self, ins, reversed_outs, args):
outs, tape = args
if reversed_outs is None:
reversed_outs = outs
# correct number of outs
if len(self.layer_next) > 1:
outs = [outs for l in self.layer_next]
if len(self.layer_next) > 1:
if len(self.input_shape) > 1:
ret = [keras_layers.Add(dtype=tf.float32)([tape.gradient(o, i, output_gradients=r) for o, r in zip(outs, reversed_outs)]) for i in ins]
else:
ret = keras_layers.Add(dtype=tf.float32)([tape.gradient(o, ins, output_gradients=r) for o, r in zip(outs, reversed_outs)])
else:
if len(self.input_shape) > 1:
ret = [tape.gradient(outs, i, output_gradients=reversed_outs) for i in ins]
else:
ret = tape.gradient(outs, ins, output_gradients=reversed_outs)
return ret
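# Illustration (not executed): for a Dense layer y = x @ W + b with a single child layer and a
# single input, explain_hook reduces to tape.gradient(outs, ins, output_gradients=R), which for
# this layer equals R @ tf.transpose(W) -- i.e., plain gradient backpropagation of the child
# explanation R.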
class ReverseModel():
"""
Defines a ReverseModel
ReverseModels are built from ReplacementLayer subclasses. A ReverseModel is defined via a list of Input ReplacementLayers (the input layers of the model)
and ReplacementLayers (the whole model)
Offers methods to
- build
- apply
- get precomputed explanations from
- save
- load
the ReverseModel
"""
def __init__(self, model, reverse_mappings, default_reverse_mapping, **kwargs):
self.build(model, reverse_mappings, default_reverse_mapping, **kwargs)
def build(self, model, reverse_mappings, default_reverse_mapping, **kwargs):
"""
Builds the ReverseModel by wrapping keras network layer(s) into ReplacementLayer(s)
:param model: tf.keras model to be wrapped and analyzed
:param reverse_mappings: callable mapping a layer to the ReplacementLayer subclass that should wrap it, or to None if the default mapping should be used
:param default_reverse_mapping: callable mapping a layer to the ReplacementLayer subclass used as the default wrapping
:returns -
"""
# obtain the layers of the model to be analyzed
layers = kgraph.get_model_layers(model)
# set all replacement layers
replacement_layers = []
for layer in layers:
layer_next = []
wrapper_class = reverse_mappings(layer)
if wrapper_class is None:
wrapper_class = default_reverse_mapping(layer)
if not issubclass(wrapper_class, ReplacementLayer):
raise ValueError("Reverse Mappings should be an instance of ReplacementLayer")
replacement_layers.append(wrapper_class(layer, layer_next, **kwargs))
# connect graph structure
for layer in replacement_layers:
for layer2 in replacement_layers:
inp = layer2.layer_func.input
out = layer.layer_func.output
if not isinstance(inp, list):
inp = [inp]
if not isinstance(out, list):
out = [out]
for i in inp:
if id(i) in [id(o) for o in out] and id(layer) != id(layer2):
layer.layer_next.append(layer2)
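# Illustration (not executed; Dense_1/Dense_2 are hypothetical layer names): for a model
# Input -> Dense_1 -> Dense_2, the Dense_1 wrapper's layer_next ends up as [<Dense_2 wrapper>]
# because Dense_1's output tensor is the same object (identical id) as Dense_2's input tensor.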
# find input access points
input_layers = []
input_layers.append(replacement_layers[0]) # TODO: HOTFIX for input layer identification
# for i, t in enumerate(model.inputs):
# for layer in replacement_layers:
# if id(layer.layer_func.output) == id(t):
# input_layers.append(layer)
# if len(input_layers) < i + 1:
# # if we did not append an input layer, we need to create one
# # TODO case for no input layer here
# raise ValueError("Temporary error. You need to explicitly define an Input Layer for now")
self._reverse_model = (input_layers, replacement_layers)
def apply(self, Xs, neuron_selection="max_activation", explained_layer_names=None, stop_mapping_at_layers=None, r_init=None, f_init=None):
"""
Computes an explanation by applying the ReverseModel
:param Xs: tensor or np.array of input to be explained. Shape (n_ins, batch_size, ...) if the model has multiple inputs, or (batch_size, ...) otherwise
:param neuron_selection: neuron_selection parameter. Used to only compute explanation w.r.t. specific output neurons. One of the following:
- None or "all"
- "max_activation"
- int
- list or np.array of int, with length equal to batch size
:param explained_layer_names: None or "all" or list of layer names whose explanations should be returned.
Can be used to obtain intermediate explanations or explanations of multiple layers
:param stop_mapping_at_layers: None or list of layers to stop mapping at ("output" layers)
:param r_init: None or Scalar or Array-Like or Dict {layer_name:scalar or array-like} reverse initialization value. Value with which the explanation is initialized.
:param f_init: None or Scalar or Array-Like or Dict {layer_name:scalar or array-like} forward initialization value. Value with which the forward is initialized.
:returns Dict of the form {layer name (string): explanation (numpy.ndarray)}
"""
# shape of Xs: (n_ins, batch_size, ...), or (batch_size, ...)
reverse_ins, reverse_layers = self._reverse_model
warnings.simplefilter("always")
if stop_mapping_at_layers is not None and (isinstance(neuron_selection, int) or isinstance(neuron_selection, list) or isinstance(neuron_selection, np.ndarray)):
warnings.warn("You are specifying layers to stop forward pass at, and also neuron-selecting by index. Please make sure the corresponding shapes fit together!")
if not isinstance(Xs, tf.Tensor):
try:
Xs = tf.constant(Xs)
except Exception:
raise ValueError(f"Xs is of unsupported type {type(Xs)}")
# format input & obtain explanations
if len(reverse_ins) == 1:
# single input network
reverse_ins[0].try_apply(tf.constant(Xs), neuron_selection=neuron_selection,
stop_mapping_at_layers=stop_mapping_at_layers, r_init=r_init, f_init=f_init)
else:
# multiple inputs. reshape to (n_ins, batch_size, ...)
for i, reverse_in in enumerate(reverse_ins):
reverse_in.try_apply(tf.constant(Xs[i]), neuron_selection=neuron_selection,
stop_mapping_at_layers=stop_mapping_at_layers, r_init=r_init, f_init=f_init)
# obtain explanations for specified layers
hm = self.get_explanations(explained_layer_names)
return hm
def get_explanations(self, explained_layer_names=None):
"""
Get results of (previously computed) explanation.
explanation of layer i has shape equal to input_shape of layer i.
:param explained_layer_names: None or "all" or list of strings containing the names of the layers.
if explained_layer_names == 'all', explanations of all layers are returned.
:returns Dict of the form {layer name (string): explanation (numpy.ndarray)}
"""
reverse_ins, reverse_layers = self._reverse_model
hm = {}
if explained_layer_names is None:
# just explain input layers
for layer in reverse_ins:
hm[layer.name] = layer.explanation.numpy()
return hm
# output everything possible
if explained_layer_names == "all":
for layer in reverse_layers:
if layer.explanation is not None:
hm[layer.name] = layer.explanation.numpy()
return hm
# otherwise, obtain explanations for specified layers
for name in explained_layer_names:
layer = [layer for layer in reverse_layers if layer.name == name]
if len(layer) > 0:
if layer[0].explanation is None:
raise AttributeError(f"layer <<{name}>> has to be analyzed before")
hm[name] = layer[0].explanation.numpy()
return hm
#TODO
def save(self):
raise NotImplementedError
#TODO
def load(self):
raise NotImplementedError
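# Usage sketch (illustrative only; assumes a small sequential tf.keras model and wraps every
# layer with GradientReplacementLayer via the default mapping, which yields plain gradients):
#
#     model = tf.keras.Sequential([
#         tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
#         tf.keras.layers.Dense(3),
#     ])
#     rmodel = ReverseModel(
#         model,
#         reverse_mappings=lambda layer: None,                        # no layer-specific mapping
#         default_reverse_mapping=lambda layer: GradientReplacementLayer,
#     )
#     explanations = rmodel.apply(
#         np.random.rand(2, 4).astype(np.float32),
#         neuron_selection="max_activation",
#         explained_layer_names="all",
#     )
#     # -> dict {layer_name: np.ndarray with the shape of that layer's input}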