Source code for signxai.tf_signxai.methods_impl.innvestigate.utils.tests.cases.trivia

# Get Python six functionality:
from __future__ import\
    absolute_import, print_function, division, unicode_literals


###############################################################################
###############################################################################
###############################################################################


import numpy as np

from .... import backend

from . import helper


__all__ = [
    "dot",
    "skip_connection",
]


###############################################################################
###############################################################################
###############################################################################


[docs] def dot(): input_shape = (1, 2) data = np.random.rand(*input_shape) if backend.name() == "tensorflow": layers = backend.keras.layers inputs = layers.Input(shape=input_shape[1:]) outputs = layers.Dense(units=1, activation="linear")(inputs) model = helper.build_keras_model(inputs, outputs) else: raise NotImplementedError() return model, data
[docs] def skip_connection(): input_shape = (1, 1) data = np.random.rand(*input_shape) if backend.name() == "tensorflow": layers = backend.keras.layers inputs = layers.Input(shape=input_shape[1:]) tmp = layers.Dense(units=1, activation="linear")(inputs) outputs = layers.Add()([inputs, tmp]) model = helper.build_keras_model(inputs, outputs) else: raise NotImplementedError() return model, data