Source code for signxai.tf_signxai.methods_impl.innvestigate.analyzer.reverse_map

import warnings
import tensorflow as tf

###############################################################################
###############################################################################
###############################################################################


import tensorflow.keras.backend as K
import numpy as np
from ..utils.keras import graph as kgraph

import tensorflow.keras.layers as keras_layers

class ReplacementLayer():
    """
    Base class for providing explainability functionality.

    This class wraps a single network layer, providing hooks for applying the layer and retrieving an explanation.
    Basically:
        * Any forward passes required for computing the explanation are defined in the apply method.
          During the forward pass, a callback is given to all child layers to retrieve their explanations.
        * Wrappers (e.g., a GradientTape) around the forward pass(es) that are required to compute an explanation
          can be defined and returned by the wrap_hook method.
        * In the wrap_hook method, forward pass(es) are applied and Tapes defined.
        * The explanation is computed in the explain_hook method and then passed to the callback functions of the
          parent ReplacementLayers.

    :param layer: layer (of base class tensorflow.keras.layers.Layer) to wrap
    :param layer_next: list of layers in the network that receive the output of layer (i.e., child layers)
    :param r_init_constant: if not None, defines a constant output mapping
    :param f_init_constant: if not None, defines a constant activation mapping

    This is just a base class. To extend it for specific XAI methods,
        - wrap_hook()
        - explain_hook()
    should be overwritten accordingly.
    """
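    # Call-sequence sketch (derived from the methods below; purely illustrative,
    # not executed): for each wrapped layer, the forward and backward passes
    # interleave as
    #
    #   parent._forward(Ys)             -> child.try_apply(Ys, parent.try_explain, ...)
    #   child.try_apply(...)            -> child.wrap_hook(...)  then  child._forward(...)
    #   ...
    #   last_layer._forward(...)        -> last_layer.try_explain(None) -> explain_hook(...)
    #   parent.try_explain(child_expl)  -> parent.explain_hook(...) once all children reported
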
    def __init__(self, layer, layer_next=None, r_init_constant=None, f_init_constant=None, **kwargs):
        # use None as default to avoid a shared mutable default list
        if layer_next is None:
            layer_next = []
        self.layer_func = layer
        self.layer_next = layer_next
        self.name = layer.name
        self.r_init = r_init_constant
        self.f_init = f_init_constant

        # Handle input_shape compatibility for modern TensorFlow/Keras
        if hasattr(layer, 'input_shape'):
            self.input_shape = layer.input_shape
        elif hasattr(layer, '_inbound_nodes') and layer._inbound_nodes:
            # Modern Keras - get input shapes from inbound nodes
            try:
                input_shapes = [tensor.shape for tensor in layer._inbound_nodes[0].input_tensors]
                self.input_shape = input_shapes[0] if len(input_shapes) == 1 else input_shapes
            except (IndexError, AttributeError):
                # Fallback to input_spec if available
                if hasattr(layer, 'input_spec') and layer.input_spec:
                    if hasattr(layer.input_spec, 'shape') and layer.input_spec.shape:
                        self.input_shape = layer.input_spec.shape
                    else:
                        # Fallback shape
                        self.input_shape = (None,)
                else:
                    self.input_shape = (None,)
        else:
            self.input_shape = (None,)

        if not isinstance(self.input_shape, list):
            self.input_shape = [self.input_shape]

        # Handle output_shape compatibility for modern TensorFlow/Keras
        if hasattr(layer, 'output_shape'):
            self.output_shape = layer.output_shape
        elif hasattr(layer, '_inbound_nodes') and layer._inbound_nodes:
            # Modern Keras - get output shapes from inbound nodes
            try:
                output_shapes = [tensor.shape for tensor in layer._inbound_nodes[0].output_tensors]
                self.output_shape = output_shapes[0] if len(output_shapes) == 1 else output_shapes
            except (IndexError, AttributeError):
                # Fallback shape based on layer type
                self.output_shape = (None,)
        else:
            self.output_shape = (None,)

        if not isinstance(self.output_shape, list):
            self.output_shape = [self.output_shape]

        self.input_vals = None
        self.original_output_vals = None
        self.reversed_output_vals = None
        self.callbacks = None
        self.hook_vals = None
        self.explanation = None

        ###############
        # TODO: remove one of the variables
        self.forward_after_stopping = True
        self.base_forward_after_stopping = False
        #############
        self.reached_after_stop_mapping = None

    def try_explain(self, reversed_outs):
        """
        Callback function called by child layers when their explanation is computed.

        * aggregates explanations of all children
        * calls explain_hook to compute own explanation
        * sends own explanation to all parent layers by calling their callback functions

        :param reversed_outs: the child layer's explanation. None if this is the last layer.
        """
        # aggregate explanations
        if reversed_outs is not None:
            if self.reversed_output_vals is None:
                self.reversed_output_vals = []
            self.reversed_output_vals.append(reversed_outs)

        # last layer or aggregation finished
        if self.reversed_output_vals is None or len(self.reversed_output_vals) == len(self.layer_next):
            # apply post hook: explain
            if self.hook_vals is None:
                raise ValueError(
                    "self.hook_vals should contain values at this point. Is self.wrap_hook working correctly?")

            input_vals = self.input_vals
            if len(input_vals) == 1:
                input_vals = input_vals[0]

            rev_outs = self.reversed_output_vals
            if rev_outs is not None:
                if len(rev_outs) == 1:
                    rev_outs = rev_outs[0]

            # print("Backward:", self.name)
            # print(self.name, len(self.layer_next))
            # print(self.name, np.shape(input_vals), np.shape(rev_outs), np.shape(self.hook_vals[0]))
            self.explanation = self.explain_hook(input_vals, rev_outs, self.hook_vals)

            # callbacks
            if self.callbacks is not None:
                # check if multiple inputs explained
                if len(self.callbacks) > 1 and not isinstance(self.explanation, list):
                    raise ValueError(self.name + ": This layer has " + str(
                        len(self.callbacks)) + " inputs, but no list of explanations was provided.")
                elif len(self.callbacks) > 1 and len(self.callbacks) != len(self.explanation):
                    raise ValueError(
                        self.name + ": This layer has " + str(len(self.callbacks)) + " inputs, but only " + str(
                            len(self.explanation)) + " explanations were computed")

                if len(self.callbacks) > 1:
                    for c, callback in enumerate(self.callbacks):
                        callback(self.explanation[c])
                else:
                    self.callbacks[0](self.explanation)

            # reset
            self.input_vals = None
            self.reversed_output_vals = None
            self.callbacks = None
            self.hook_vals = None

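    # Aggregation sketch (illustrative, not executed): for a layer with two
    # child layers, try_explain buffers the first child's relevance and only
    # runs explain_hook once both children have reported:
    #
    #   layer.try_explain(rel_from_child_a)   # buffered in self.reversed_output_vals
    #   layer.try_explain(rel_from_child_b)   # len == len(layer_next) -> explain_hook + parent callbacks
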
    def _forward(self, Ys, neuron_selection=None, stop_mapping_at_layers=None, r_init=None, f_init=None):
        """
        Forward pass to all child layers.

        * If this is the last layer, directly calls try_explain to compute own explanation
        * Otherwise calls try_apply on all child layers

        :param Ys: output of own forward pass
        :param neuron_selection: neuron_selection parameter (see try_apply)
        :param stop_mapping_at_layers: stop_mapping_at_layers parameter (see try_apply)
        :param r_init: None or scalar or array-like or dict {layer_name: scalar or array-like}.
                       Reverse initialization value, i.e., the value with which the explanation is initialized
                       (head_mapping).
        :param f_init: None or scalar or array-like or dict {layer_name: scalar or array-like}.
                       Forward initialization value, i.e., the value with which the forward pass is initialized.
        """
        # print("Forward: ", self.name)
        if len(self.layer_next) == 0:
            # last layer: directly compute explanation
            self.try_explain(None)
        elif stop_mapping_at_layers is not None and self.name in stop_mapping_at_layers:
            self.try_explain(None)

            #########################
            # TODO: New Code, remove comments if it is good
            # if self.base_forward_after_stopping:
            #     for layer_n in self.layer_next:
            #         layer_n.try_apply(Ys, None, neuron_selection, stop_mapping_at_layers, r_init)
            #########################

            # TODO: Leander checks if this code is still necessary
            if self.forward_after_stopping:
                # if the output was mapped, this restores the original for the forwarding
                if self.original_output_vals is not None:
                    Ys = self.original_output_vals
                    self.original_output_vals = None

                # make a dummy callback so that the logic does not get buggy
                def dummyCallback(reversed_outs):
                    pass

                for layer_n in self.layer_next:
                    layer_n.try_apply(Ys, dummyCallback, neuron_selection, stop_mapping_at_layers, r_init, f_init)
            #############################
        else:
            # forward
            for layer_n in self.layer_next:
                layer_n.try_apply(Ys, self.try_explain, neuron_selection, stop_mapping_at_layers, r_init, f_init)

    @tf.custom_gradient
    def _toNumber(self, X, value):
        """
        Helper function to set a tensor to a fixed value while having a gradient of "value".
        """
        y = tf.constant(value, dtype=tf.float32, shape=X.shape)

        def grad(dy, variables=None):
            # variables=None and None as output necessary as _toNumber requires two arguments
            return dy * tf.ones(X.shape) * value, None

        return y, grad

    def _neuron_select(self, Ys, neuron_selection):
        """
        Performs neuron_selection on Ys.

        :param Ys: output of own forward pass
        :param neuron_selection: neuron_selection parameter (see try_apply)
        """
        # error handling is done before, in try_apply
        if neuron_selection is None:
            Ys = Ys
        elif isinstance(neuron_selection, tf.Tensor):
            # flatten and then filter by the neuron_selection index
            Ys = tf.reshape(Ys, (Ys.shape[0], np.prod(Ys.shape[1:])))
            Ys = tf.gather_nd(Ys, neuron_selection, batch_dims=1)
        else:
            Ys = K.max(Ys, axis=1, keepdims=True)
        return Ys

    def _head_mapping(self, Ys, model_output_value=None):
        """
        Sets the model output to a fixed value. Used as initialization for the explanation method.

        :param model_output_value: output value of the model / initialization value for the explanation method
        """
        if model_output_value is not None:
            if isinstance(model_output_value, dict):
                if self.name in model_output_value.keys():
                    # model_output_value should be int or array-like. Shape should fit.
                    Ys = self._toNumber(Ys, model_output_value[self.name])
            else:
                # model_output_value should be int or array-like. Shape should fit.
                Ys = self._toNumber(Ys, model_output_value)
        return Ys

    def _neuron_sel_and_head_map(self, Ys, neuron_selection=None, model_output_value=None):
        # save original output
        self.original_output_vals = Ys
        # apply neuron selection and head mapping
        Ys = self._neuron_select(Ys, neuron_selection)
        Ys = self._head_mapping(Ys, model_output_value)
        return Ys

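    # Minimal sketch of what _neuron_select does for an index tensor
    # (assumed shapes; `Ys` is a (batch, n_outputs) tensor here):
    #
    #   idx = tf.constant([[2], [2], [2]])           # select output neuron 2 for a batch of 3
    #   flat = tf.reshape(Ys, (Ys.shape[0], np.prod(Ys.shape[1:])))
    #   sel = tf.gather_nd(flat, idx, batch_dims=1)  # shape (3, 1)
    #
    # With neuron_selection=None the output is left untouched; for the default
    # "max_activation" path K.max(Ys, axis=1, keepdims=True) is used instead.
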
    def try_apply(self, ins, callback=None, neuron_selection=None, stop_mapping_at_layers=None, r_init=None,
                  f_init=None):
        """
        Tries to apply own forward pass:

        * aggregates inputs and callbacks of all parent layers
        * performs a canonization of the neuron_selection parameter
        * calls wrap_hook (wrapped forward pass(es))
        * calls _forward (forwards the result of the forward pass to the child layers)

        :param ins: input of this layer, i.e., the output of the parent layers' forward pass
        :param neuron_selection: neuron_selection parameter. One of the following:
            - None or "all"
            - "max_activation"
            - int
            - list or np.array of int, with length equal to batch size
        :param stop_mapping_at_layers: None or list of layers to stop mapping at
        :param callback: callback function of the parent layer that called self.try_apply
        :param r_init: None or scalar or array-like or dict {layer_name: scalar or array-like}.
                       Reverse initialization value, i.e., the value with which the explanation is initialized
                       (head_mapping).
        :param f_init: None or scalar or array-like or dict {layer_name: scalar or array-like}.
                       Forward initialization value, i.e., the value with which the forward pass is initialized.
        """
        # DEBUG
        # print(self.name, self.input_shape, np.shape(ins))

        # use the class attributes if they are not None
        if self.r_init is not None:
            r_init = self.r_init
        if self.f_init is not None:
            f_init = self.f_init

        # aggregate inputs
        if self.input_vals is None:
            self.input_vals = []
        self.input_vals.append(ins)

        # aggregate callbacks
        if callback is not None:
            if self.callbacks is None:
                self.callbacks = []
            self.callbacks.append(callback)

        # reset explanation
        self.explanation = None

        # apply layer only if all inputs have been collected. Then reset inputs.
        if len(self.input_vals) == len(self.input_shape):
            # set inputs to f_init, if it is not None
            if f_init is not None:
                if isinstance(f_init, dict):
                    if self.name in f_init.keys():
                        # f_init should be int or array-like. Shape should fit.
                        for i, in_val in enumerate(self.input_vals):
                            self.input_vals[i] = self._toNumber(in_val, f_init[self.name])
                else:
                    # f_init should be int or array-like. Shape should fit.
                    for i, in_val in enumerate(self.input_vals):
                        self.input_vals[i] = self._toNumber(in_val, f_init)

            # tensorify wrap_hook inputs as much as possible for graph efficiency
            input_vals = self.input_vals
            if len(input_vals) == 1:
                input_vals = input_vals[0]

            # adapt neuron_selection param
            if len(self.layer_next) == 0 or (
                    stop_mapping_at_layers is not None and self.name in stop_mapping_at_layers):
                if neuron_selection is None:
                    neuron_selection_tmp = None
                elif isinstance(neuron_selection, str) and neuron_selection == "all":
                    neuron_selection_tmp = None
                elif isinstance(neuron_selection, str) and neuron_selection == "max_activation":
                    neuron_selection_tmp = "max_activation"
                elif isinstance(neuron_selection, int):
                    neuron_selection_tmp = [[neuron_selection] for n in range(self.input_vals[0].shape[0])]
                    neuron_selection_tmp = tf.constant(neuron_selection_tmp)
                elif isinstance(neuron_selection, list) or (
                        hasattr(neuron_selection, "shape") and len(neuron_selection.shape) == 1):
                    neuron_selection_tmp = [[n] for n in neuron_selection]
                    neuron_selection_tmp = tf.constant(neuron_selection_tmp)
                else:
                    raise ValueError(
                        "Parameter neuron_selection only accepts the following values: "
                        "None, 'all', 'max_activation', <int>, <list>, <one-dimensional array>")
            else:
                neuron_selection_tmp = neuron_selection

            # apply layer and wrappers
            self.hook_vals = self.wrap_hook(input_vals, neuron_selection_tmp, stop_mapping_at_layers, r_init)

            # forward
            if isinstance(self.hook_vals, tuple):
                self._forward(self.hook_vals[0], neuron_selection, stop_mapping_at_layers, r_init, f_init)
            else:
                self._forward(self.hook_vals, neuron_selection, stop_mapping_at_layers, r_init, f_init)

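    # Canonization sketch (illustrative): at the final (or stop) layer,
    # try_apply turns the user-facing neuron_selection into the form that
    # _neuron_select expects, e.g. for a batch size of 2:
    #
    #   neuron_selection=5        -> tf.constant([[5], [5]])
    #   neuron_selection=[1, 3]   -> tf.constant([[1], [3]])
    #   neuron_selection="all"    -> None (no selection)
    #   neuron_selection="max_activation" is passed through as a string.
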
    def wrap_hook(self, ins, neuron_selection, stop_mapping_at_layers, r_init):
        """
        Hook that wraps and applies the layer function, e.g., by defining a GradientTape.

        * should contain a call to self._neuron_select
        * may define any wrappers around the forward pass

        :param ins: input(s) of this layer
        :param neuron_selection: neuron_selection parameter (see try_apply)
        :param stop_mapping_at_layers: None or stop_mapping_at_layers parameter (see try_apply)
        :param r_init: reverse initialization value, i.e., the value with which the explanation is initialized
                       (head_mapping)

        :returns: output of the layer function plus any wrappers that were defined and are needed in explain_hook

        To be extended for specific XAI methods.
        """
        outs = self.layer_func(ins)

        # check if final layer (i.e., no next layers)
        if len(self.layer_next) == 0 or (stop_mapping_at_layers is not None and self.name in stop_mapping_at_layers):
            outs = self._neuron_sel_and_head_map(outs, neuron_selection, r_init)

        return outs

    def explain_hook(self, ins, reversed_outs, args):
        """
        Hook that computes the explanations.

        * Core XAI functionality

        :param ins: input(s) of this layer
        :param args: outputs of wrap_hook (any parameters that may be needed to compute the explanation)
        :param reversed_outs: either the backpropagated explanation(s) of the child layers,
                              or None if this is the last layer

        :returns: explanation, or tensor of multiple explanations if the layer has multiple inputs (one for each)

        To be extended for specific XAI methods.
        """
        outs = args

        if reversed_outs is None:
            reversed_outs = outs

        if len(self.layer_next) > 1:
            # TODO: is this addition correct?
            ret = keras_layers.Add(dtype=tf.float32)([r for r in reversed_outs])
        elif len(self.input_shape) > 1:
            ret = [reversed_outs for i in self.input_shape]
            ret = tf.keras.layers.concatenate(ret, axis=1)
        else:
            ret = reversed_outs
        return ret

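# Minimal sketch of a custom subclass (hypothetical, for illustration only):
# a "pass-through" mapping that runs the plain forward pass in wrap_hook and
# simply hands the children's relevance back to its parents in explain_hook.
#
#   class PassThroughReplacementLayer(ReplacementLayer):
#       def wrap_hook(self, ins, neuron_selection, stop_mapping_at_layers, r_init):
#           outs = self.layer_func(ins)
#           if len(self.layer_next) == 0 or (
#                   stop_mapping_at_layers is not None and self.name in stop_mapping_at_layers):
#               outs = self._neuron_sel_and_head_map(outs, neuron_selection, r_init)
#           return outs
#
#       def explain_hook(self, ins, reversed_outs, args):
#           return args if reversed_outs is None else reversed_outs
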

class GradientReplacementLayer(ReplacementLayer):
    """
    Simple extension of ReplacementLayer.

    * Explains by computing the gradients of the layer's outputs w.r.t. its inputs
    """
    def __init__(self, *args, **kwargs):
        super(GradientReplacementLayer, self).__init__(*args, **kwargs)

    def wrap_hook(self, ins, neuron_selection, stop_mapping_at_layers, r_init):
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(ins)
            outs = self.layer_func(ins)

            # check if final layer (i.e., no next layers)
            if len(self.layer_next) == 0 or (
                    stop_mapping_at_layers is not None and self.name in stop_mapping_at_layers):
                outs = self._neuron_sel_and_head_map(outs, neuron_selection=neuron_selection,
                                                     model_output_value=r_init)
                # print('GradientReplacementLayer init:', outs, r_init)

        return outs, tape

    def explain_hook(self, ins, reversed_outs, args):
        outs, tape = args

        if reversed_outs is None:
            reversed_outs = outs

        # correct number of outs
        if len(self.layer_next) > 1:
            outs = [outs for l in self.layer_next]

        if len(self.layer_next) > 1:
            if len(self.input_shape) > 1:
                ret = [keras_layers.Add(dtype=tf.float32)(
                    [tape.gradient(o, i, output_gradients=r) for o, r in zip(outs, reversed_outs)]) for i in ins]
            else:
                ret = keras_layers.Add(dtype=tf.float32)(
                    [tape.gradient(o, ins, output_gradients=r) for o, r in zip(outs, reversed_outs)])
        else:
            if len(self.input_shape) > 1:
                ret = [tape.gradient(outs, i, output_gradients=reversed_outs) for i in ins]
            else:
                ret = tape.gradient(outs, ins, output_gradients=reversed_outs)
        return ret

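    # Note on explain_hook (informal): tape.gradient(outs, ins, output_gradients=reversed_outs)
    # computes a vector-Jacobian product,
    #   R_in = (d outs / d ins)^T . R_out,
    # so chaining GradientReplacementLayers backpropagates these products through
    # the graph, seeded at the selected output neuron(s) (or at r_init via the head mapping).
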

class ReverseModel():
    """
    Defines a ReverseModel.

    ReverseModels are built from ReplacementLayer subclasses.
    A ReverseModel is defined via a list of input ReplacementLayers (the input layers of the model)
    and ReplacementLayers (the whole model).

    Offers methods to
        - build
        - apply
        - get precomputed explanations from
        - save
        - load
    the ReverseModel.
    """
    def __init__(self, model, reverse_mappings, default_reverse_mapping, **kwargs):
        self.build(model, reverse_mappings, default_reverse_mapping, **kwargs)

    def build(self, model, reverse_mappings, default_reverse_mapping, **kwargs):
        """
        Builds the ReverseModel by wrapping the Keras network layer(s) into ReplacementLayer(s).

        :param model: tf.keras model to be replaced
        :param reverse_mappings: mapping layer -> reverse mapping (ReplacementLayer or some subclass thereof)
        :param default_reverse_mapping: ReplacementLayer or some subclass thereof; default mapping to use

        :returns: -
        """
        # build model that is to be analyzed
        layers = kgraph.get_model_layers(model)

        # set all replacement layers
        replacement_layers = []
        for layer in layers:
            layer_next = []
            wrapper_class = reverse_mappings(layer)
            if wrapper_class is None:
                wrapper_class = default_reverse_mapping(layer)
            if not issubclass(wrapper_class, ReplacementLayer):
                raise ValueError("Reverse mappings should be a subclass of ReplacementLayer")
            replacement_layers.append(wrapper_class(layer, layer_next, **kwargs))

        # connect graph structure
        for layer in replacement_layers:
            for layer2 in replacement_layers:
                inp = layer2.layer_func.input
                out = layer.layer_func.output
                if not isinstance(inp, list):
                    inp = [inp]
                if not isinstance(out, list):
                    out = [out]

                for i in inp:
                    if id(i) in [id(o) for o in out] and id(layer) != id(layer2):
                        layer.layer_next.append(layer2)

        # find input access points
        input_layers = []
        input_layers.append(replacement_layers[0])  # TODO: HOTFIX for input layer identification
        # for i, t in enumerate(model.inputs):
        #     for layer in replacement_layers:
        #         if id(layer.layer_func.output) == id(t):
        #             input_layers.append(layer)
        #     if len(input_layers) < i + 1:
        #         # if we did not append an input layer, we need to create one
        #         # TODO case for no input layer here
        #         raise ValueError("Temporary error. You need to explicitly define an Input Layer for now")

        self._reverse_model = (input_layers, replacement_layers)

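    # Usage sketch for the build() arguments (hypothetical helper, for
    # illustration only): reverse_mappings is a callable that maps a Keras
    # layer to a ReplacementLayer subclass, or returns None to fall back to
    # default_reverse_mapping:
    #
    #   def my_reverse_mappings(layer):
    #       if isinstance(layer, tf.keras.layers.InputLayer):
    #           return ReplacementLayer
    #       return None  # -> default_reverse_mapping(layer) decides
    #
    #   rm = ReverseModel(model, my_reverse_mappings,
    #                     default_reverse_mapping=lambda layer: GradientReplacementLayer)
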
    def apply(self, Xs, neuron_selection="max_activation", explained_layer_names=None, stop_mapping_at_layers=None,
              r_init=None, f_init=None):
        """
        Computes an explanation by applying the ReverseModel.

        :param Xs: tensor or np.array of the input to be explained.
                   Shape (n_ins, batch_size, ...) if the model has multiple inputs, or (batch_size, ...) otherwise.
        :param neuron_selection: neuron_selection parameter. Used to only compute the explanation w.r.t. specific
                                 output neurons. One of the following:
            - None or "all"
            - "max_activation"
            - int
            - list or np.array of int, with length equal to batch size
        :param explained_layer_names: None or "all" or list of layer names whose explanations should be returned.
                                      Can be used to obtain intermediate explanations or explanations of multiple
                                      layers.
        :param stop_mapping_at_layers: None or list of layers to stop mapping at ("output" layers)
        :param r_init: None or scalar or array-like or dict {layer_name: scalar or array-like}.
                       Reverse initialization value. Value with which the explanation is initialized.
        :param f_init: None or scalar or array-like or dict {layer_name: scalar or array-like}.
                       Forward initialization value. Value with which the forward pass is initialized.

        :returns: dict of the form {layer name (string): explanation (numpy.ndarray)}
        """
        # shape of Xs: (n_ins, batch_size, ...), or (batch_size, ...)
        reverse_ins, reverse_layers = self._reverse_model

        warnings.simplefilter("always")
        if stop_mapping_at_layers is not None and isinstance(neuron_selection, (int, list, np.ndarray)):
            warnings.warn(
                "You are specifying layers to stop the forward pass at, and also neuron-selecting by index. "
                "Please make sure the corresponding shapes fit together!")

        if not isinstance(Xs, tf.Tensor):
            try:
                Xs = tf.constant(Xs)
            except Exception:
                raise ValueError("Xs has unsupported type ", type(Xs))

        # format input & obtain explanations
        if len(reverse_ins) == 1:
            # single-input network
            reverse_ins[0].try_apply(tf.constant(Xs), neuron_selection=neuron_selection,
                                     stop_mapping_at_layers=stop_mapping_at_layers, r_init=r_init, f_init=f_init)
        else:
            # multiple inputs. reshape to (n_ins, batch_size, ...)
            for i, reverse_in in enumerate(reverse_ins):
                reverse_in.try_apply(tf.constant(Xs[i]), neuron_selection=neuron_selection,
                                     stop_mapping_at_layers=stop_mapping_at_layers, r_init=r_init, f_init=f_init)

        # obtain explanations for specified layers
        hm = self.get_explanations(explained_layer_names)

        return hm

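    # Usage sketch (assumed model and variable names, not part of this module):
    #
    #   x_batch = np.random.rand(2, 224, 224, 3).astype(np.float32)
    #   expl = rm.apply(x_batch, neuron_selection="max_activation")
    #   input_heatmap = next(iter(expl.values()))   # {input layer name: np.ndarray}
    #
    # Passing neuron_selection=3 instead explains output neuron 3 for every
    # sample in the batch; explained_layer_names="all" additionally returns the
    # intermediate layers' explanations.
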
    def get_explanations(self, explained_layer_names=None):
        """
        Get the results of a (previously computed) explanation.
        The explanation of layer i has a shape equal to the input_shape of layer i.

        :param explained_layer_names: None or "all" or list of strings containing the names of the layers.
                                      If explained_layer_names == 'all', the explanations of all layers are returned.

        :returns: dict of the form {layer name (string): explanation (numpy.ndarray)}
        """
        reverse_ins, reverse_layers = self._reverse_model

        hm = {}
        if explained_layer_names is None:
            # just explain input layers
            for layer in reverse_ins:
                hm[layer.name] = layer.explanation.numpy()
            return hm

        # output everything possible
        if explained_layer_names == "all":
            for layer in reverse_layers:
                if layer.explanation is not None:
                    hm[layer.name] = layer.explanation.numpy()
            return hm

        # otherwise, obtain explanations for specified layers
        for name in explained_layer_names:
            layer = [layer for layer in reverse_layers if layer.name == name]
            if len(layer) > 0:
                if layer[0].explanation is None:
                    raise AttributeError(f"layer <<{name}>> has to be analyzed before")
                hm[name] = layer[0].explanation.numpy()

        return hm

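    # Example (hypothetical layer names): after rm.apply(...) has run,
    #
    #   rm.get_explanations()                       # input layer(s) only
    #   rm.get_explanations("all")                  # every analyzed layer
    #   rm.get_explanations(["conv2d", "dense"])    # selected layers by name
    #
    # each return a dict {layer name: numpy.ndarray}.
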
    # TODO
    def save(self):
        raise NotImplementedError

    # TODO
    def load(self):
        raise NotImplementedError