from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter
from ..utils import _raise, move_channel_for_backend, axes_dict, axes_check_and_normalize, backend_channels_last
from ..internals.losses import loss_laplace, loss_mse, loss_mae, loss_thresh_weighted_decay
import numpy as np
from ..utils.tf import keras_import, BACKEND as K
# Keras classes are resolved through the project's `keras_import` helper instead of
# being imported directly (presumably so the configured Keras flavor is used --
# confirm in `..utils.tf`).
Callback, TerminateOnNaN = keras_import('callbacks', 'Callback', 'TerminateOnNaN')
Sequence = keras_import('utils', 'Sequence')
Optimizer = keras_import('optimizers', 'Optimizer')
class ParameterDecayCallback(Callback):
    """Keras callback that shrinks a backend variable at the end of every epoch.

    At the end of epoch ``e`` (as passed by Keras) the current value ``v`` of
    `parameter` is replaced by ``v / (1 + decay * (e + 1))``. The factor is
    applied to the *current* value, so the reduction compounds over epochs.

    Parameters
    ----------
    parameter
        Backend variable, read with ``K.get_value`` and written with
        ``K.set_value``.
    decay
        Decay rate used in the formula above.
    name : str or None
        If given, the value *before* the update is stored in the epoch logs
        under this key; it is also used in the verbose message.
    verbose
        If truthy, print the new value after every update.
    """

    def __init__(self, parameter, decay, name=None, verbose=0):
        # Let the Keras base class set up its own state.
        super(ParameterDecayCallback, self).__init__()
        self.parameter = parameter
        self.decay = decay
        self.name = name
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs=None):
        """Record the current value in `logs` (if named) and apply one decay step."""
        old_val = K.get_value(self.parameter)
        # Write into the dict handed in by Keras. `logs or {}` would silently
        # swap an empty (falsy) dict for a throwaway one and lose the entry.
        if self.name and logs is not None:
            logs[self.name] = old_val
        new_val = old_val * (1. / (1. + self.decay * (epoch + 1)))
        K.set_value(self.parameter, new_val)
        if self.verbose:
            print("\n[ParameterDecayCallback] new %s: %s\n" % (self.name if self.name else 'parameter', new_val))
def prepare_model(model, optimizer, loss, metrics=('mse','mae'),
                  loss_bg_thresh=0, loss_bg_decay=0.06, Y=None):
    """Compile `model` with the requested loss/metrics and return training callbacks.

    Parameters
    ----------
    model
        Keras model; ``model.compile`` is called on it. For ``loss='laplace'``
        ``model.output.shape`` is inspected as well.
    optimizer
        Instance of the Keras ``Optimizer`` class.
    loss : str
        Name suffix of a ``loss_<name>`` factory visible in this module
        (e.g. ``'mse'``, ``'mae'``, ``'laplace'``).
    metrics : iterable of str
        Names of additional losses to track, resolved like `loss`.
    loss_bg_thresh : float
        Value in ``[0, 1]``. ``0`` selects the plain loss; any other value
        wraps the per-pixel loss with ``loss_thresh_weighted_decay`` using
        this threshold.
    loss_bg_decay : float
        Decay rate handed to `ParameterDecayCallback` for the weighting
        variable of the thresholded loss.
    Y
        Target data, required when ``loss_bg_thresh != 0``. Only used to
        compute the fraction of entries above the threshold, so it must
        support ``Y > loss_bg_thresh`` (e.g. a numpy array).

    Returns
    -------
    list
        Callbacks to pass to training: always ``TerminateOnNaN``, plus a
        `ParameterDecayCallback` when the thresholded loss is used.

    Raises
    ------
    ValueError
        If `optimizer` is not an ``Optimizer`` instance.
    AssertionError
        If the threshold/`Y` combination or the laplace prerequisites are
        violated.
    NameError
        If `loss` or an entry of `metrics` does not name a known
        ``loss_<name>`` factory.
    """
    if not isinstance(optimizer, Optimizer):
        raise ValueError("optimizer must be an instance of a Keras Optimizer")

    # NOTE(security): the loss/metric names are spliced into `eval`. Only pass
    # trusted strings here; never forward unvalidated user input.
    loss_standard = eval('loss_%s()'%loss)
    _metrics = [eval('loss_%s()'%m) for m in metrics]
    callbacks = [TerminateOnNaN()]

    # checks
    assert 0 <= loss_bg_thresh <= 1
    assert loss_bg_thresh == 0 or Y is not None
    if loss == 'laplace':
        assert K.image_data_format() == "channels_last", \
            "loss 'laplace' requires the 'channels_last' image data format"
        # The laplace loss needs an even number (>= 2) of output channels.
        n_channel_out = list(model.output.shape)[-1]
        assert n_channel_out >= 2 and n_channel_out % 2 == 0

    # loss
    if loss_bg_thresh == 0:
        _loss = loss_standard
    else:
        # Fraction of target entries above the threshold.
        freq = np.mean(Y > loss_bg_thresh)
        alpha = K.variable(1.0)
        loss_per_pixel = eval('loss_{loss}(mean=False)'.format(loss=loss))
        _loss = loss_thresh_weighted_decay(loss_per_pixel, loss_bg_thresh,
                                           0.5 / (0.1 + (1 - freq)),
                                           0.5 / (0.1 + freq),
                                           alpha)
        # `alpha` is decayed after every epoch by the callback.
        callbacks.append(ParameterDecayCallback(alpha, loss_bg_decay, name='alpha'))
        # Still report the unweighted loss when training on the weighted one.
        if loss not in metrics:
            _metrics.append(loss_standard)

    # compile model
    model.compile(optimizer=optimizer, loss=_loss, metrics=_metrics)
    return callbacks
class RollingSequence(Sequence):
    """Keras ``Sequence`` producing batches of data indices that roll over the data.

    Each batch contains `batch_size` indices from ``range(data_size)``. Once the
    indices of one pass are used up, the next pass is appended, so consecutive
    batches visit every index equally often. Passes are in ascending order, or
    independently permuted when ``shuffle=True``.

    The index order of each pass is cached, hence ``batch(i)`` always returns
    the same indices for the same ``i``. `length` only determines ``len(self)``.
    `batch_size` may exceed `data_size`.
    """

    def __init__(self, data_size, batch_size, length=None, shuffle=True, rng=None, keras_kwargs=None):
        base_kwargs = keras_kwargs if keras_kwargs is not None else {}
        super(RollingSequence, self).__init__(**base_kwargs)
        if rng is None:
            rng = np.random
        self.data_size = int(data_size)
        self.batch_size = int(batch_size)
        if length is None:
            # 2**63-1 is max possible value
            self.length = 2**63-1
        else:
            self.length = int(length)
        self.shuffle = bool(shuffle)
        # Factory mapping a size to the index order of one pass over the data.
        if self.shuffle:
            self.index_gen = rng.permutation
        else:
            self.index_gen = np.arange
        # Cache: pass number -> index order of that pass.
        self.index_map = {}

    def __len__(self):
        """Return the configured number of batches."""
        return self.length

    def _index(self, loop):
        """Return the index order of pass `loop`, creating and caching it on first use."""
        if loop not in self.index_map:
            return self.index_map.setdefault(loop, self.index_gen(self.data_size))
        return self.index_map[loop]

    def on_epoch_end(self):
        """Do nothing at the end of an epoch."""
        pass

    def __iter__(self):
        """Yield ``self[0], self[1], ...`` up to ``len(self)`` items."""
        for batch_no in range(len(self)):
            yield self[batch_no]

    def batch(self, i):
        """Return the array of data indices belonging to batch number `i`."""
        # Locate the first element of the batch: which pass, and offset within it.
        current_loop, start = divmod(i * self.batch_size, self.data_size)
        stop = start + self.batch_size
        indices = self._index(current_loop)
        # Append following passes until the batch fits.
        while stop > len(indices):
            current_loop += 1
            indices = np.concatenate((indices, self._index(current_loop)))
        return indices[start:stop]

    def __getitem__(self, i):
        """Return batch `i`; same as ``batch(i)``."""
        return self.batch(i)
class DataWrapper(RollingSequence):
    """Rolling, shuffled sequence of ``(X, Y)`` batches with optional augmentation.

    `X` and `Y` must have the same length and support indexing with an array
    of indices (e.g. numpy arrays). If `augmenter` is given, it is called as
    ``augmenter(x, y)`` for every pair of a batch and must return the
    augmented pair; the results are stacked into new batch arrays.
    """

    def __init__(self, X, Y, batch_size, length, augmenter=None, keras_kwargs=None):
        super(DataWrapper, self).__init__(
            data_size=len(X),
            batch_size=batch_size,
            length=length,
            shuffle=True,
            keras_kwargs=keras_kwargs,
        )
        if len(X) != len(Y):
            _raise(ValueError("X and Y must have same length"))
        self.X = X
        self.Y = Y
        self.augmenter = augmenter

    def __getitem__(self, i):
        """Return the ``(X, Y)`` arrays of batch `i`, augmented if configured."""
        selection = self.batch(i)
        batch_x = self.X[selection]
        batch_y = self.Y[selection]
        if self.augmenter is None:
            return batch_x, batch_y
        # Augment pair by pair, then regroup into all inputs / all targets.
        augmented_pairs = tuple(self.augmenter(x, y) for x, y in zip(batch_x, batch_y))
        batch_x, batch_y = tuple(zip(*augmented_pairs))
        batch_x, batch_y = np.stack(batch_x), np.stack(batch_y)
        return batch_x, batch_y