from __future__ import (absolute_import, division, print_function,
                        unicode_literals)
###############################################################################
###############################################################################
###############################################################################
import inspect
import warnings

import tensorflow.keras.layers as keras_layers

from .. import utils as iutils
from . import reverse_map
from ..utils.keras import checks as kchecks
from ..utils.keras import graph as kgraph
__all__ = [
"NotAnalyzeableModelException",
"AnalyzerBase",
"TrainerMixin",
"OneEpochTrainerMixin",
"AnalyzerNetworkBase",
"ReverseAnalyzerBase"
]
###############################################################################
###############################################################################
###############################################################################
class NotAnalyzeableModelException(Exception):
"""Indicates that the model cannot be analyzed by an analyzer."""
pass
class AnalyzerBase(object):
""" The basic interface of an iNNvestigate analyzer.
This class defines the basic interface for analyzers:
# >>> model = create_keras_model()
# >>> a = Analyzer(model)
# >>> a.fit(X_train) # If analyzer needs training.
# >>> analysis = a.analyze(X_test)
# >>>
# >>> state = a.save()
# >>> a_new = A.load(*state)
# >>> analysis = a_new.analyze(X_test)
:param model: A tf.keras model.
:param disable_model_checks: Do not execute model checks that enforce
compatibility of analyzer and model.
.. note:: To develop a new analyzer derive from
:class:`AnalyzerNetworkBase`.
"""
def __init__(self, model, disable_model_checks=False, **kwargs):
self._model = model
self._disable_model_checks = disable_model_checks
self._do_model_checks()
def _add_model_check(self, check, message, check_type="exception"):
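        """Registers a check that :func:`_do_model_checks` runs against the model.

        :param check: Function f(layer) -> bool that flags offending layers.
        :param message: Message reported when the check triggers.
        :param check_type: Either "exception" or "warning".
        """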
if getattr(self, "_model_check_done", False):
raise Exception("Cannot add model check anymore."
" Check was already performed.")
if not hasattr(self, "_model_checks"):
self._model_checks = []
check_instance = {
"check": check,
"message": message,
"type": check_type,
}
self._model_checks.append(check_instance)
def _do_model_checks(self):
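        """Runs all registered model checks; raises or warns on offending layers."""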
model_checks = getattr(self, "_model_checks", [])
if not self._disable_model_checks and len(model_checks) > 0:
check = [x["check"] for x in model_checks]
types = [x["type"] for x in model_checks]
messages = [x["message"] for x in model_checks]
checked = kgraph.model_contains(self._model, check)
tmp = zip(iutils.to_list(checked), messages, types)
for checked_layers, message, check_type in tmp:
if len(checked_layers) > 0:
tmp_message = ("%s\nCheck triggerd by layers: %s" %
(message, checked_layers))
if check_type == "exception":
raise NotAnalyzeableModelException(tmp_message)
elif check_type == "warning":
# TODO(albermax) only the first warning will be shown
warnings.warn(tmp_message)
else:
raise NotImplementedError()
self._model_check_done = True
def fit(self, *args, **kwargs):
"""
Stub that eats arguments. If an analyzer needs training
include :class:`TrainerMixin`.
        :param disable_no_training_warning: Do not warn if this function is
          called although no training is needed.
"""
disable_no_training_warning = kwargs.pop("disable_no_training_warning",
False)
if not disable_no_training_warning:
            # Issue a warning if no training is foreseen,
            # but fit() is still called.
warnings.warn("This analyzer does not need to be trained."
" Still fit() is called.", RuntimeWarning)
def fit_generator(self, *args, **kwargs):
"""
Stub that eats arguments. If an analyzer needs training
include :class:`TrainerMixin`.
        :param disable_no_training_warning: Do not warn if this function is
          called although no training is needed.
"""
disable_no_training_warning = kwargs.pop("disable_no_training_warning",
False)
if not disable_no_training_warning:
            # Issue a warning if no training is foreseen,
            # but fit_generator() is still called.
warnings.warn("This analyzer does not need to be trained."
" Still fit_generator() is called.", RuntimeWarning)
def analyze(self, X):
"""
Analyze the behavior of model on input `X`.
:param X: Input as expected by model.
"""
raise NotImplementedError()
###############################################################################
###############################################################################
###############################################################################
class TrainerMixin(object):
"""Mixin for analyzer that adapt to data.
This convenience interface exposes a Keras like training routing
to the user.
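
    A typical call looks like this (sketch; ``SomeTrainableAnalyzer`` is a
    hypothetical subclass that implements :func:`_fit_generator`):

    >>> analyzer = SomeTrainableAnalyzer(model)
    >>> analyzer.fit(X_train, batch_size=32)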
"""
# todo: extend with Y
def fit(self,
X=None,
batch_size=32,
**kwargs):
"""
Takes the same parameters as Keras's :func:`model.fit` function.
"""
generator = iutils.BatchSequence(X, batch_size)
return self._fit_generator(generator,
**kwargs)
def fit_generator(self, *args, **kwargs):
"""
Takes the same parameters as Keras's :func:`model.fit_generator`
function.
"""
return self._fit_generator(*args, **kwargs)
def _fit_generator(self,
generator,
steps_per_epoch=None,
epochs=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
verbose=0,
disable_no_training_warning=None):
raise NotImplementedError()
class OneEpochTrainerMixin(TrainerMixin):
"""Exposes the same interface and functionality as :class:`TrainerMixin`
except that the training is limited to one epoch.
"""
def fit(self, *args, **kwargs):
"""
        Same interface as :func:`fit` of :class:`TrainerMixin` except that
        the parameter `epochs` is fixed to 1.
"""
return super(OneEpochTrainerMixin, self).fit(*args, epochs=1, **kwargs)
def fit_generator(self, *args, **kwargs):
"""
        Same interface as :func:`fit_generator` of :class:`TrainerMixin`
        except that the parameter `epochs` is fixed to 1.
"""
steps = kwargs.pop("steps", None)
return super(OneEpochTrainerMixin, self).fit_generator(
*args,
steps_per_epoch=steps,
epochs=1,
**kwargs)
###############################################################################
###############################################################################
###############################################################################
class AnalyzerNetworkBase(AnalyzerBase):
"""Convenience interface for analyzers.
    This class provides helpful functionality to create analyzers.
    Basically it:

    * takes the input model and adds a layer that selects
      the desired output neuron to analyze,
    * passes the new model to :func:`_create_analysis`, which should
      return the analysis as Keras tensors,
    * compiles the function and serves the output to :func:`analyze` calls,
    * allows :func:`_create_analysis` to return tensors
      that are intercepted for debugging purposes.
:param allow_lambda_layers: Allow the model to contain lambda layers.
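
    A minimal subclass only needs to implement :func:`_create_analysis`,
    e.g. (sketch; ``MyReverseModel`` is a hypothetical stand-in for the
    object returned by :func:`_create_analysis`):

    >>> class MyAnalyzer(AnalyzerNetworkBase):
    ...     def _create_analysis(self, model, **kwargs):
    ...         return MyReverseModel(model)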
"""
def __init__(self, model,
allow_lambda_layers=False,
**kwargs):
self._allow_lambda_layers = allow_lambda_layers
self._analyzed = False
self._add_model_check(
lambda layer: (not self._allow_lambda_layers and
isinstance(layer, keras_layers.Lambda)),
("Lamda layers are not allowed. "
"To force use set allow_lambda_layers parameter."),
check_type="exception",
)
super(AnalyzerNetworkBase, self).__init__(model, **kwargs)
def _add_model_softmax_check(self):
"""
        Adds a check that prevents analyzed models from containing a softmax activation.
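
        Typically called from a subclass constructor before
        ``super().__init__`` runs the checks, e.g. (sketch)::

            def __init__(self, model, **kwargs):
                self._add_model_softmax_check()
                super().__init__(model, **kwargs)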
"""
self._add_model_check(
lambda layer: kchecks.contains_activation(
layer, activation="softmax"),
"This analysis method does not support softmax layers.",
check_type="exception",
)
def create_analyzer_model(self, **kwargs):
"""
Creates the analyze functionality. If not called beforehand
it will be called by :func:`analyze`.
"""
self._analyzer_model = self._create_analysis(self._model, **kwargs)
def _create_analysis(self, model, **kwargs):
"""
Interface that needs to be implemented by a derived class.
        This function is expected to create a custom analysis for the model
        inputs given the model outputs.

        :param model: Target of analysis.
        :return: The reversed "model" as a list of input layers and a list
          of wrapped layers.
"""
raise NotImplementedError()
def _handle_debug_output(self, debug_values):
raise NotImplementedError()
    def analyze(self, X, neuron_selection="max_activation",
                explained_layer_names=None, stop_mapping_at_layers=None,
                r_init=None, f_init=None, **kwargs):
"""
        Takes an array-like input X and explains it. Also applies
        postprocessing to the explanation.

        :param X: tensor or np.array of input to be explained.
          Shape (n_ins, batch_size, ...) if the model has multiple inputs,
          or (batch_size, ...) otherwise.
        :param neuron_selection: neuron_selection parameter. Used to only
          compute the explanation w.r.t. specific output neurons.
          One of the following:

          - None or "all"
          - "max_activation"
          - int
          - list or np.array of int, with length equal to batch size
        :param explained_layer_names: None or "all" or list of layer names
          whose explanations should be returned. Can be used to obtain
          intermediate explanations or explanations of multiple layers.
          If layer names are provided, a dictionary is returned.
        :param stop_mapping_at_layers: None or list of layers to stop
          mapping at ("output" layers).
        :param r_init: None or Scalar or Array-Like or Dict
          {layer_name: scalar or array-like}. Reverse initialization value,
          i.e., the value with which the explanation is initialized.
        :param f_init: None or Scalar or Array-Like or Dict
          {layer_name: scalar or array-like}. Forward initialization value,
          i.e., the value with which the forward pass is initialized.
        :returns: Dict of the form {layer name (string): explanation (numpy.ndarray)}
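
        Example (sketch; ``SomeAnalyzer`` is a hypothetical subclass,
        ``model`` and ``x`` are assumed to exist):

        >>> analyzer = SomeAnalyzer(model)
        >>> analysis = analyzer.analyze(x, neuron_selection="max_activation")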
"""
if not hasattr(self, "_analyzer_model"):
self.create_analyzer_model(**kwargs)
if isinstance(explained_layer_names, list):
for l in explained_layer_names:
if not isinstance(l, str):
raise AttributeError("Parameter explained_layer_names has to be None, 'all', or a list of strings")
        elif explained_layer_names is not None and explained_layer_names != "all":
            # not a list, not None, and not "all"
            raise AttributeError("Parameter explained_layer_names has to be None, 'all', or a list of strings")
if isinstance(stop_mapping_at_layers, list):
for l in stop_mapping_at_layers:
if not isinstance(l, str):
raise AttributeError("Parameter stop_mapping_at_layers has to be None or a list of strings")
elif stop_mapping_at_layers is not None:
# not list and not None
raise AttributeError("Parameter stop_mapping_at_layers has to be None or a list of strings")
        # Check whether a layer before the layers in stop_mapping_at_layers
        # is connected to a layer after stop_mapping_at_layers.
        # If so, the forward pass has to be done for every layer in the model.
if stop_mapping_at_layers is not None:
in_layers, rev_layer = self._analyzer_model._reverse_model
for il in in_layers:
if self._is_resnet_like(il, stop_mapping_at_layers, False) == 0:
for rl in rev_layer:
rl.base_forward_after_stopping = True
ret = self._analyzer_model.apply(X,
neuron_selection=neuron_selection,
explained_layer_names=explained_layer_names,
stop_mapping_at_layers=stop_mapping_at_layers,
r_init=r_init,
f_init=f_init
)
self._analyzed = True
ret = self._postprocess_analysis(ret)
return ret
def _postprocess_analysis(self, hm):
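        """Hook that lets subclasses postprocess the computed explanations."""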
return hm
def _is_resnet_like(self, layer, stop_mapping_at_layers, after_stop_mapping):
"""
recursive function to check if there are layers that have connections reaching layers behind stop_mapping_at_layers
param layer: start point
"""
next_layers = layer.layer_next
if len(next_layers) == 0:
# reached last node, return "everything ok" as default
return 1
# current layer is part of stop mapping
if stop_mapping_at_layers is not None and layer.name in stop_mapping_at_layers:
# boolean signifies whether next layers are after a stop mapping layer
after_stop_mapping = True
result_child = 1
for nl in next_layers:
if nl.reached_after_stop_mapping is not None:
# next layer already visited before
                if nl.reached_after_stop_mapping != after_stop_mapping:
                    # A layer before stop mapping is connected to a layer
                    # after stop mapping: conflict.
                    return 0
            if nl.reached_after_stop_mapping is None:
                # First visit: record whether this next layer lies after a
                # stop-mapping layer.
                nl.reached_after_stop_mapping = after_stop_mapping
result_child = result_child and self._is_resnet_like(nl, stop_mapping_at_layers, after_stop_mapping)
return result_child
def get_explanations(self, explained_layer_names=None):
"""
        Get the results of a (previously computed) explanation.
        The explanation of layer i has a shape equal to the input shape of layer i.

        :param explained_layer_names: None or "all" or list of strings
          containing the names of the layers.
          If explained_layer_names == "all", explanations of all layers are returned.
        :returns: Dict of the form {layer name (string): explanation (numpy.ndarray)}
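
        Example (sketch; the layer names are hypothetical and
        :func:`analyze` must have been called before):

        >>> analyzer.analyze(x)
        >>> hms = analyzer.get_explanations(["conv2d_1", "dense_1"])
        >>> hms["conv2d_1"].shape  # equals the input shape of conv2d_1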
"""
if not hasattr(self, "_analyzer_model"):
self.create_analyzer_model()
if not self._analyzed:
raise AttributeError("You have to analyze the model before intermediate results are available!")
        if isinstance(explained_layer_names, list):
            for l in explained_layer_names:
                if not isinstance(l, str):
                    raise AttributeError("Parameter explained_layer_names has to be None, 'all', or a list of strings")
        elif explained_layer_names is not None and explained_layer_names != "all":
            # not a list, not None, and not "all"
            raise AttributeError("Parameter explained_layer_names has to be None, 'all', or a list of strings")
hm = self._analyzer_model.get_explanations(explained_layer_names)
hm = self._postprocess_analysis(hm)
return hm
class ReverseAnalyzerBase(AnalyzerNetworkBase):
"""Convenience class for analyzers that revert the model's structure.
This class contains many helper functions around the graph
reverse function :func:`innvestigate.utils.keras.graph.reverse_model`.
The deriving classes should specify how the graph should be reverted.
"""
def __init__(self,
model,
**kwargs):
super(ReverseAnalyzerBase, self).__init__(model, **kwargs)
def _gradient_reverse_mapping(self):
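        """Returns the gradient-based :class:`ReplacementLayer` used as reverse mapping."""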
return reverse_map.GradientReplacementLayer
def _reverse_mapping(self, layer):
"""
This function should return a reverse mapping for the passed layer.
If this function returns None, :func:`_default_reverse_mapping`
is applied.
:param layer: The layer for which a mapping should be returned.
:return: The mapping can be of the following forms:
* A :class:`ReplacementLayer` subclass.
"""
return self._apply_conditional_reverse_mappings(layer)
def _add_conditional_reverse_mapping(
self, condition, mapping, priority=-1, name=None):
"""
        Adds a conditional reverse mapping that is applied to layers
        matching `condition`. If no conditional mapping applies,
        :func:`_default_reverse_mapping` is used.
:param condition: Condition when this mapping should be applied.
Form: f(layer) -> bool
        :param mapping: The mapping can be of the following forms:

          * A function of form f(layer) that returns
            a :class:`ReplacementLayer` subclass.
          * A :class:`ReplacementLayer` subclass.
        :param priority: The higher the priority, the earlier the
          condition gets evaluated.
:param name: An identifying name.
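
        Example (sketch; ``MyDenseMapping`` is a hypothetical
        :class:`ReplacementLayer` subclass)::

            self._add_conditional_reverse_mapping(
                condition=lambda layer: isinstance(layer, keras_layers.Dense),
                mapping=MyDenseMapping,
                priority=10,
                name="dense_mapping",
            )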
"""
if getattr(self, "_reverse_mapping_applied", False):
raise Exception("Cannot add conditional mapping "
"after first application.")
if not hasattr(self, "_conditional_reverse_mappings"):
self._conditional_reverse_mappings = {}
if priority not in self._conditional_reverse_mappings:
self._conditional_reverse_mappings[priority] = []
tmp = {"condition": condition, "mapping": mapping, "name": name}
self._conditional_reverse_mappings[priority].append(tmp)
def _apply_conditional_reverse_mappings(self, layer):
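        """Returns the first conditional mapping whose condition matches `layer`, or None."""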
mappings = getattr(self, "_conditional_reverse_mappings", {})
self._reverse_mapping_applied = True
# Search for mapping. First consider ones with highest priority,
# inside priority in order of adding.
sorted_keys = sorted(mappings.keys())[::-1]
for key in sorted_keys:
for mapping in mappings[key]:
if mapping["condition"](layer):
                    if (inspect.isclass(mapping["mapping"]) and
                            issubclass(mapping["mapping"],
                                       reverse_map.ReplacementLayer)):
return mapping["mapping"]
elif callable(mapping["mapping"]):
return mapping["mapping"](layer)
return None
def _default_reverse_mapping(self, layer):
"""
Fallback function to map layer
"""
return reverse_map.GradientReplacementLayer
def _create_analysis(self, model, **kwargs):
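        """Builds a :class:`reverse_map.ReverseModel` from the configured mappings."""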
analyzer_model = reverse_map.ReverseModel(
model,
reverse_mappings=self._reverse_mapping,
default_reverse_mapping=self._default_reverse_mapping,
**kwargs
)
return analyzer_model