# Get Python six functionality:
from __future__ import\
absolute_import, print_function, division, unicode_literals
from builtins import range
###############################################################################
###############################################################################
###############################################################################
import matplotlib.pyplot as plt
import numpy as np
# Public API of this module: the names exported by a star-import.
__all__ = [
"project",
"heatmap",
"graymap",
"gamma",
"clip_quantile",
]
###############################################################################
###############################################################################
###############################################################################
# [docs]
def project(X, output_range=(0, 1), absmax=None, input_is_positive_only=False):
    """Projects a tensor into a value range.

    Every sample (entry along the first axis) is divided by its absolute
    maximum and the result is mapped linearly into ``output_range``.
    The input is never modified; a new array is returned.

    :param X: A tensor (array-like). Non-floating input is converted
      to ``float64`` before normalizing.
    :param output_range: The output value range as ``(low, high)``.
    :param absmax: A scalar or tensor specifying the absmax used for
      normalizing. Its axes are aligned with the leading axes of ``X``.
      Default: the absmax of each sample, i.e. the maximum of ``|X|``
      over all axes but the first.
    :param input_is_positive_only: Is the input value range only positive.
      If ``False``, normalized values are shifted from [-1, 1] to [0, 1].
    :return: The tensor with the values projected into the output range.
    """
    X = np.asarray(X)
    if not np.issubdtype(X.dtype, np.inexact):
        # True division is not defined in-place for integers; work in float.
        X = X.astype(np.float64)

    if absmax is None:
        absmax = np.max(np.abs(X), axis=tuple(range(1, X.ndim)))
    absmax = np.asarray(absmax)
    # Append singleton axes so that absmax lines up with the *leading*
    # axes of X. Plain broadcasting would align it with the trailing
    # axes and normalize along the wrong dimension.
    absmax = absmax.reshape(absmax.shape + (1,) * (X.ndim - absmax.ndim))

    # Entries with a zero absmax are left untouched (divided by one).
    divisor = np.where(absmax != 0, absmax, np.ones_like(absmax))
    # Out-of-place division keeps the caller's array intact; the cast
    # preserves the floating dtype of the input.
    X = (X / divisor).astype(X.dtype, copy=False)

    if input_is_positive_only is False:
        X = (X + 1) / 2  # [-1, 1] -> [0, 1]
    X = X.clip(0, 1)

    low, high = output_range[0], output_range[1]
    return low + (X * (high - low))
# [docs]
def heatmap(X, cmap_type="seismic", reduce_op="sum", reduce_axis=-1, alpha_cmap=False, **kwargs):
    """Creates a heatmap/color map.

    Reduces the input tensor along ``reduce_axis``, projects the result
    into the integer range [0, 255] with :func:`project` and looks the
    values up in a matplotlib color map. The color channels take the
    place of the reduced axis.

    :param X: An image tensor with 4 axes.
    :param cmap_type: The color map to use. Default 'seismic'.
    :param reduce_op: Operation to reduce the color axis.
      Either 'sum' or 'absmax'. 'absmax' keeps, with its sign, the
      entry of largest magnitude along the axis.
    :param reduce_axis: Axis to reduce.
    :param alpha_cmap: Should the alpha component of the cmap be included.
    :param kwargs: Arguments passed on to :func:`project`
    :return: The tensor as color-map (``float32``); the reduced axis has
      size 3, or 4 if ``alpha_cmap`` is set.
    :raises NotImplementedError: If ``reduce_op`` is not supported.
    """
    # pyplot.get_cmap is available across matplotlib versions, whereas
    # matplotlib.cm.get_cmap has been removed from recent releases.
    cmap = plt.get_cmap(cmap_type)

    if reduce_op == "sum":
        reduced = X.sum(axis=reduce_axis)
    elif reduce_op == "absmax":
        pos_max = X.max(axis=reduce_axis)
        neg_min = X.min(axis=reduce_axis)
        # Keep the signed value with the larger magnitude; ties go to
        # the positive side.
        reduced = np.where(pos_max >= -neg_min, pos_max, neg_min)
    else:
        raise NotImplementedError("Unsupported reduce_op: %r" % (reduce_op,))

    # Integer input makes the color map index its lookup table directly.
    indices = project(reduced, output_range=(0, 255), **kwargs).astype(np.int64)

    n_channels = 4 if alpha_cmap else 3
    colors = cmap(indices.flatten())[:, :n_channels]
    # The lookup yields one color per reduced element, channels last.
    # Restore the reduced layout first, then move the channel axis to
    # where the reduced axis was (a no-op for the default last axis).
    colors = colors.reshape(indices.shape + (n_channels,))
    colors = np.moveaxis(colors, -1, reduce_axis)
    return colors.astype(np.float32)
# [docs]
def graymap(X, **kwargs):
    """Creates a gray-scale color map.

    Convenience wrapper around :func:`heatmap` with the color map
    fixed to 'gray'.

    :param X: The image tensor, handed to :func:`heatmap` unchanged.
    :param kwargs: Further arguments passed on to :func:`heatmap`.
    :return: The result of :func:`heatmap`.
    """
    gray_image = heatmap(X, cmap_type="gray", **kwargs)
    return gray_image
# [docs]
def gamma(X, gamma=0.5, minamp=0, maxamp=None):
    """Apply gamma correction to an input array X.

    The relative order of the entries is maintained, also for negative
    vs. positive values in X. The values are shifted by ``minamp``,
    scaled linearly by ``maxamp``, gamma-scaled separately for positive
    and non-positive values, and finally mapped back to the original
    scale and center. The input is not modified.

    :param X: The input array (array-like). Non-floating input is
      converted to ``float64`` so the result is not truncated.
    :param gamma: the gamma parameter for gamma scaling
    :param minamp: the smallest absolute value to consider.
      if not given assumed to be zero (neutral value for relevance,
      min value for saliency, ...). values above and below
      minamp are treated separately.
    :param maxamp: the largest absolute value to consider relative
      to the neutral value minamp.
      if not given determined from the given data.
    :return: A new array with the gamma-corrected values. If X is empty
      or the amplitude is zero, an unchanged copy is returned.
    """
    values = np.asarray(X)
    if not np.issubdtype(values.dtype, np.inexact):
        # An integer result array would silently truncate the powers.
        values = values.astype(np.float64)

    # Nothing to scale; also avoids taking the max of an empty array.
    if values.size == 0:
        return values.copy()

    shifted = values - minamp  # shift to given/assumed center
    if maxamp is None:
        maxamp = np.abs(shifted).max()  # infer maxamp if not given
    if maxamp == 0:
        # Zero amplitude: dividing would only produce NaNs, so the
        # values are returned as they are.
        return values.copy()
    scaled = shifted / maxamp  # scale linearly

    # Apply gamma correction for both positive and negative values.
    result = np.zeros_like(scaled)
    positive = scaled > 0
    non_positive = np.invert(positive)
    result[positive] = scaled[positive] ** gamma
    result[non_positive] = -((-scaled[non_positive]) ** gamma)

    # Reconstruct original scale and center.
    result *= maxamp
    result += minamp
    return result
# [docs]
def clip_quantile(X, quantile=1):
    """Clip the values of X into the given quantile.

    Values below the lower percentile are set to it and values above
    the upper percentile are set to it. The clipping is done in place
    and the (modified) input array is returned.

    :param X: The array to clip; it is modified in place.
    :param quantile: Either a single percentile ``q``, which clips to
      the range between the ``q``-th and ``(100 - q)``-th percentile,
      or a list/tuple with two percentiles. The order of the two
      bounds does not matter.
    :return: The clipped array X.
    """
    if not isinstance(quantile, (list, tuple)):
        quantile = (quantile, 100 - quantile)

    # Order the bounds so that e.g. quantile=99 clips to the same range
    # as quantile=1 instead of collapsing every value to one number.
    lower_q, upper_q = sorted((quantile[0], quantile[1]))

    # Both percentiles are taken before anything is modified.
    low = np.percentile(X, lower_q)
    high = np.percentile(X, upper_q)

    X[X < low] = low
    X[X > high] = high
    return X