# Source code for this module (the module name was lost when the page was extracted).

from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter

import numpy as np
import os
import warnings
import shutil
import datetime
from importlib import import_module
from packaging import version

from tensorflow import __version__ as _v_tf
v_tf = version.parse(_v_tf)
IS_TF_1 = v_tf.major == 1
IS_TF_2_6_PLUS = v_tf >= version.parse('2.6')
IS_TF_2_16_PLUS = v_tf >= version.parse('2.16')

def err_old_keras(v_keras, package='keras', min_version=None):
    """Create (but do not raise) the error for a Keras package that is too old.

    Parameters
    ----------
    v_keras : version
        Installed version of the Keras package (only used for the message).
    package : str, optional
        Name of the package to mention in the message and in the pip command.
    min_version : str or None, optional
        Minimum version to recommend. If ``None``, '<major>.<minor>' of the
        installed 'tensorflow' (module-level ``v_tf``) is used.

    Returns
    -------
    RuntimeError
        The exception instance; the caller is expected to ``raise`` it.
    """
    if min_version is None:
        min_version = "{v_tf.major}.{v_tf.minor}".format(v_tf=v_tf)
    return RuntimeError(\
"""
Found version {v_keras} of '{package}', which is too old for the installed version {v_tf} of 'tensorflow'.
Please update '{package}': pip install "{package}>={min_version}"
""".format(v_keras=v_keras, v_tf=v_tf, package=package, min_version=min_version))

# Detect which Keras package matches the installed TensorFlow and validate its version.
# Default for all branches that cannot (or must not) end up with Keras 3+:
# TF 1.x, TF 2.16+ with legacy 'tf_keras', and TF < 2.6 with the bundled 'tensorflow.keras'.
IS_KERAS_3_PLUS = False

if IS_TF_1:
    # TF 1.x needs the stand-alone 'keras' package.
    try:
        from keras import __version__ as _v_keras
        v_keras = version.parse(_v_keras)
    except ModuleNotFoundError:
        # NOTE(review): the URL that followed "compatible versions of 'keras':" was lost
        # when this file was extracted; restore it from the upstream source.
        raise RuntimeError(\
"""
For 'tensorflow' 1.x (found version {v_tf}), the stand-alone 'keras' package (2.1.2 <= version < 2.4) is required.
When using the most recent version of 'tensorflow' 1.x, install 'keras' like this: pip install "keras>=2.1.2,<2.4"

If you must use an older version of 'tensorflow' 1.x, please see this file for compatible versions of 'keras':
""".format(v_tf=v_tf))

elif IS_TF_2_16_PLUS:
    # See the Keras documentation -> "Moving from Keras 2 to Keras 3"
    if os.environ.get('TF_USE_LEGACY_KERAS','0') == '1':
        # User explicitly asked for Keras 2, which is provided by the 'tf_keras' package.
        try:
            from tf_keras import __version__ as _v_keras
            v_keras = version.parse(_v_keras)
            assert v_keras.major < 3
            if (v_tf.major,v_tf.minor) > (v_keras.major,v_keras.minor):
                raise err_old_keras(v_keras, package='tf_keras')
            warnings.warn("Using Keras 2 (via 'tf_keras' package) with Tensorflow 2.16+ might work, but is not regularly tested.")
        except ModuleNotFoundError:
            raise RuntimeError("Found environment variable 'TF_USE_LEGACY_KERAS=1' but 'tf_keras' package not installed.")
    else:
        # Default for TF 2.16+: the stand-alone 'keras' package must be version 3 or newer.
        try:
            from keras import __version__ as _v_keras
            v_keras = version.parse(_v_keras)
            IS_KERAS_3_PLUS = v_keras.major >= 3
            if not IS_KERAS_3_PLUS:
                raise err_old_keras(v_keras, min_version='3')
        except ModuleNotFoundError:
            raise RuntimeError("Package 'keras' is not installed.")

elif IS_TF_2_6_PLUS:
    # TF 2.6 - 2.15: stand-alone 'keras' with (at least) the same major.minor as 'tensorflow'.
    try:
        from keras import __version__ as _v_keras
        v_keras = version.parse(_v_keras)
        IS_KERAS_3_PLUS = v_keras.major >= 3
        if IS_KERAS_3_PLUS:
            warnings.warn("Using Keras version 3+ with Tensorflow below version 2.16 might work, but has not been tested.")
        else:
            if (v_tf.major,v_tf.minor) > (v_keras.major,v_keras.minor):
                raise err_old_keras(v_keras)
    except ModuleNotFoundError:
        # The two version clauses must be comma-separated to form a valid pip requirement.
        raise RuntimeError(\
"""
Starting with 'tensorflow' version 2.6.0, a recent version of the stand-alone 'keras' package is required.
Please install/update 'keras': pip install "keras>={v_tf.major}.{v_tf.minor},<3"
""".format(v_tf=v_tf))

else:
    # TF 2.0 - 2.5: use the Keras that is bundled with TensorFlow.
    from tensorflow.keras import __version__ as _v_keras
    v_keras = version.parse(_v_keras)
    assert v_keras.major < 3

# Name of the package that `keras_import` resolves sub-modules from.
_KERAS = 'keras' if (IS_TF_1 or IS_KERAS_3_PLUS) else 'tensorflow.keras'
def keras_import(sub=None, *names):
    """Import from the Keras package selected for this environment (``_KERAS``).

    Parameters
    ----------
    sub : str or None, optional
        Name of the sub-module (e.g. ``'layers'``). If ``None``, the top-level
        Keras package itself is returned and `names` is ignored.
    *names : str
        Attribute names to fetch from the sub-module.

    Returns
    -------
    module, object or tuple
        The (sub-)module if no `names` are given, the single attribute if
        exactly one name is given, otherwise a tuple of attributes in the
        order of `names`.
    """
    if sub is None:
        return import_module(_KERAS)
    mod = import_module('{_KERAS}.{sub}'.format(_KERAS=_KERAS, sub=sub))
    if not names:
        return mod
    if len(names) == 1:
        return getattr(mod, names[0])
    return tuple(getattr(mod, name) for name in names)

import tensorflow as tf
# if IS_TF_1:
#     import tensorflow as tf
# else:
#     import tensorflow.compat.v1 as tf
#     # tf.disable_v2_behavior()

# Resolve the selected Keras package once and expose commonly used members.
keras = keras_import()
K = keras_import('backend')
Callback = keras_import('callbacks', 'Callback')
Lambda = keras_import('layers', 'Lambda')

from .utils import _raise, is_tf_backend, save_json, backend_channels_last, normalize
from .six import tempfile

if IS_KERAS_3_PLUS:
    # 'ops' is imported only in this branch; attributes missing from the
    # backend module `K` are looked up there as a fallback.
    OPS = keras_import('ops')
    _remap = dict(tf=tf, int_shape=OPS.shape)
    class Backend(object):
        """Attribute proxy: `_remap` entries first, then `K`, then `OPS`."""
        def __getattr__(self, name):
            if name in _remap:
                return _remap[name]
            elif hasattr(K, name):
                return getattr(K, name)
            else:
                return getattr(OPS, name)
else:
    _remap = dict(tf=tf)
    class Backend(object):
        """Attribute proxy: `_remap` entries first, then the Keras backend `K`."""
        def __getattr__(self, name):
            if name in _remap:
                return _remap[name]
            else:
                return getattr(K, name)

# Shared instance; e.g. `BACKEND.tf` yields the imported tensorflow module.
BACKEND = Backend()

def limit_gpu_memory(fraction, allow_growth=False, total_memory=None):
    """Limit GPU memory allocation for TensorFlow (TF) backend.

    Parameters
    ----------
    fraction : float
        Limit TF to use only a fraction (value between 0 and 1) of the available GPU memory.
        Reduced memory allocation can be disabled if fraction is set to ``None``.
    allow_growth : bool, optional
        If ``False`` (default), TF will allocate all designated (see `fraction`) memory all at once.
        If ``True``, TF will allocate memory as needed up to the limit imposed by `fraction`; this may
        incur a performance penalty due to memory fragmentation.
    total_memory : int
        Total amount of available GPU memory (in MB).
        Must be a scalar when using TensorFlow 2 and `fraction` is not ``None``;
        it is not used with TensorFlow 1.

    Raises
    ------
    ValueError
        If `fraction` is not ``None`` or a float value between 0 and 1, or if
        `total_memory` is required (see above) but not a scalar.
    NotImplementedError
        If TensorFlow is not used as the backend.
    """

    is_tf_backend() or _raise(NotImplementedError('Not using tensorflow backend.'))
    fraction is None or (np.isscalar(fraction) and 0<=fraction<=1) or _raise(ValueError('fraction must be between 0 and 1.'))

    if IS_TF_1:
        # An already existing Keras session means it is too late to change the GPU options.
        _session = None
        try:
            _session = K.tensorflow_backend._SESSION
        except AttributeError:
            pass

        if _session is None:
            config = tf.ConfigProto()
            if fraction is not None:
                config.gpu_options.per_process_gpu_memory_fraction = fraction
            config.gpu_options.allow_growth = bool(allow_growth)
            session = tf.Session(config=config)
            # Register the configured session with Keras; without this the config has no effect.
            K.tensorflow_backend.set_session(session)
        else:
            warnings.warn('Too late to limit GPU memory, can only be done once and before any computation.')
    else:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            if fraction is not None:
                np.isscalar(total_memory) or _raise(ValueError("'total_memory' must be provided when using TensorFlow 2."))
                vdc = tf.config.experimental.VirtualDeviceConfiguration(memory_limit=int(np.ceil(total_memory*fraction)))
            try:
                for gpu in gpus:
                    if fraction is not None:
                        tf.config.experimental.set_virtual_device_configuration(gpu, [vdc])
                    if allow_growth:
                        tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                # must be set before GPUs have been initialized
                warnings.warn(str(e))
def export_SavedModel(model, outpath, meta={}, format='zip'):
    """Export Keras model in TensorFlow's SavedModel_ format.

    See `Your Model in Fiji`_ to learn how to use the exported model with our CSBDeep Fiji plugins.

    .. _SavedModel:
    .. _`Your Model in Fiji`:

    Parameters
    ----------
    model : :class:`keras.models.Model`
        Keras model to be exported.
    outpath : str
        Path of the file/folder that the model will exported to.
    meta : dict, optional
        Metadata to be saved in an additional ``meta.json`` file.
        (The default dict is only read, never modified.)
    format : str, optional
        Can be 'dir' to export as a directory or 'zip' (default) to export as a ZIP file.

    Raises
    ------
    ValueError
        Illegal arguments.
    NotImplementedError
        When Keras 3+ is used.

    """
    # NOTE(review): the URLs of the two link targets in the docstring and the help URLs
    # after "please read:" in the messages below were lost when this file was extracted;
    # restore them from the upstream source.

    if IS_KERAS_3_PLUS:
        raise NotImplementedError(\
"""Exporting to SavedModel is no longer supported with Keras 3+.

It is likely that the exported model *will only work* in associated ImageJ/Fiji
plugins (e.g. CSBDeep and StarDist) when using 'tensorflow' 1.x to export the model.

The current workaround is to load the trained model in a Python environment with
installed 'tensorflow' 1.x and then export it again. If you need help with this, please read:
""")

    def export_to_dir(dirname):
        """Write the SavedModel (and optional ``meta.json``) into directory `dirname`."""
        if len(model.inputs) > 1 or len(model.outputs) > 1:
            warnings.warn('Found multiple input or output layers.')

        def _export(model):
            if IS_TF_1:
                from tensorflow import saved_model
                from keras.backend import get_session
            else:
                from tensorflow.compat.v1 import saved_model
                from tensorflow.compat.v1.keras.backend import get_session
                warnings.warn(\
"""
***IMPORTANT NOTE***

You are using 'tensorflow' 2.x, hence it is likely that the exported model *will not work*
in associated ImageJ/Fiji plugins (e.g. CSBDeep and StarDist).

If you indeed have problems loading the exported model in Fiji, the current workaround is
to load the trained model in a Python environment with installed 'tensorflow' 1.x and then
export it again.
If you need help with this, please read:
""")

            builder = saved_model.builder.SavedModelBuilder(dirname)
            # use name 'input'/'output' if there's just a single input/output layer
            inputs  = dict(zip(model.input_names,model.inputs))   if len(model.inputs)  > 1 else dict(input=model.input)
            outputs = dict(zip(model.output_names,model.outputs)) if len(model.outputs) > 1 else dict(output=model.output)
            signature = saved_model.signature_def_utils.predict_signature_def(inputs=inputs, outputs=outputs)
            signature_def_map = { saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }
            builder.add_meta_graph_and_variables(get_session(),
                                                 [saved_model.tag_constants.SERVING],
                                                 signature_def_map=signature_def_map)
            # Actually write the SavedModel to disk.
            builder.save()

        if IS_TF_1:
            _export(model)
        else:
            from tensorflow.keras.models import clone_model
            weights = model.get_weights()
            with tf.Graph().as_default():
                # clone model in new graph and set weights
                _model = clone_model(model)
                _model.set_weights(weights)
                _export(_model)

        if meta is not None and len(meta) > 0:
            save_json(meta, os.path.join(dirname,'meta.json'))

    ## checks
    isinstance(model,keras.models.Model) or _raise(ValueError("'model' must be a Keras model."))
    # supported_formats = tuple(['dir']+[name for name,description in shutil.get_archive_formats()])
    supported_formats = 'dir','zip'
    format in supported_formats or _raise(ValueError("Unsupported format '%s', must be one of %s." % (format,str(supported_formats))))

    # remove '.zip' file name extension if necessary
    if format == 'zip' and outpath.endswith('.zip'):
        outpath = os.path.splitext(outpath)[0]

    if format == 'dir':
        export_to_dir(outpath)
    else:
        # export to a temporary directory first, then pack its content into the archive
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpsubdir = os.path.join(tmpdir,'model')
            export_to_dir(tmpsubdir)
            shutil.make_archive(outpath, format, tmpsubdir)
def tf_normalize(x, pmin=1, pmax=99.8, axis=None, clip=False):
    """Percentile-based normalization of tensor `x`.

    Computes ``(x - mi) / (ma - mi + epsilon)`` where `mi` and `ma` are the
    `pmin`-th and `pmax`-th percentiles of `x` along `axis`; the result is
    clipped to [0, 1] if `clip` is true.
    Relies on ``tf.contrib``; within this file it is only used by the
    TensorFlow-1-only `CARETensorBoard` callback.
    """
    assert pmin < pmax
    mi = tf.contrib.distributions.percentile(x,pmin, axis=axis, keep_dims=True)
    ma = tf.contrib.distributions.percentile(x,pmax, axis=axis, keep_dims=True)
    y = (x-mi)/(ma-mi+K.epsilon())
    if clip:
        y = K.clip(y,0,1.0)
    return y


def tf_normalize_layer(layer, pmin=1, pmax=99.8, clip=True):
    """Normalize a tensor with `tf_normalize` after reducing it for display.

    Tensors with more than 4 dimensions are max-projected over axes
    ``1 .. n_dim-4``. Depending on the number of channels (last axis), the
    tensor is then normalized as is (1 or 3 channels), with its first channel
    appended as a third one (2 channels), or after a max over all channels
    (more than 3 channels) - presumably to obtain 1 or 3 channels for image summaries.
    """
    def norm(x,axis):
        return tf_normalize(x, pmin=pmin, pmax=pmax, axis=axis, clip=clip)

    shape = K.int_shape(layer)
    n_channels_out = shape[-1]
    n_dim_out = len(shape)

    if n_dim_out > 4:
        layer = Lambda(lambda x: K.max(x, axis=tuple(range(1,1+n_dim_out-4))))(layer)

    assert 0 < n_channels_out

    if n_channels_out == 1:
        out = Lambda(lambda x: norm(x, axis=(1,2)))(layer)
    elif n_channels_out == 2:
        out = Lambda(lambda x: norm(K.concatenate([x,x[...,:1]], axis=-1), axis=(1,2,3)))(layer)
    elif n_channels_out == 3:
        out = Lambda(lambda x: norm(x, axis=(1,2,3)))(layer)
    else:
        out = Lambda(lambda x: norm(K.max(x, axis=-1, keepdims=True), axis=(1,2,3)))(layer)
    return out


class CARETensorBoard(Callback):
    """TensorBoard callback (TensorFlow 1 only) that writes scalar logs, optional weight/output
    histograms, and normalized images of model inputs, targets and outputs.

    TODO: complete documentation.
    """
    def __init__(self, log_dir='./logs',
                 freq=1,
                 compute_histograms=False,
                 n_images=3,
                 prob_out=False,
                 write_graph=False,
                 prefix_with_timestamp=True,
                 write_images=False,
                 image_for_inputs=None,      # write images for only these input indices
                 image_for_outputs=None,     # write images for only these output indices
                 input_slices=None,          # list (of list) of slices to apply to `image_for_inputs` layers before writing image
                 output_slices=None,         # list (of list) of slices to apply to `image_for_outputs` layers before writing image
                 output_target_shapes=None): # list of shapes of the target/gt images that correspond to all model outputs
        super(CARETensorBoard, self).__init__()
        is_tf_backend() or _raise(RuntimeError('TensorBoard callback only works with the TensorFlow backend.'))
        backend_channels_last() or _raise(NotImplementedError())
        IS_TF_1 or _raise(NotImplementedError("Not supported with TensorFlow 2"))

        self.freq = freq
        self.image_freq = freq
        self.prob_out = prob_out
        self.merged = None
        self.gt_outputs = None
        self.write_graph = write_graph
        self.write_images = write_images
        self.n_images = n_images
        self.image_for_inputs = image_for_inputs
        self.image_for_outputs = image_for_outputs
        self.input_slices = input_slices
        self.output_slices = output_slices
        self.output_target_shapes = output_target_shapes
        self.compute_histograms = compute_histograms

        if prefix_with_timestamp:
            # use a fresh sub-directory per run, named after the current time
            log_dir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.%f"))

        self.log_dir = log_dir

    def set_model(self, model):
        """Build all summary ops for `model` and create the summary writer."""
        self.model = model
        self.sess = K.get_session()
        tf_sums = []

        if self.compute_histograms and self.freq and self.merged is None:
            for layer in self.model.layers:
                for weight in layer.weights:
                    tf_sums.append(tf.summary.histogram(weight.name, weight))

                if hasattr(layer, 'output'):
                    tf_sums.append(tf.summary.histogram('{}_out'.format(layer.name), layer.output))

        def _gt_shape(output_shape,target_shape):
            # shape of the ground truth placeholder; with `prob_out`, only half of the channels are the mean
            if target_shape is not None:
                output_shape = target_shape
            if not self.prob_out:
                return output_shape
            output_shape[-1] % 2 == 0 or _raise(ValueError())
            return list(output_shape[:-1]) + [output_shape[-1] // 2]

        n_inputs, n_outputs = len(self.model.inputs), len(self.model.outputs)
        image_for_inputs  = np.arange(n_inputs)  if self.image_for_inputs  is None else self.image_for_inputs
        image_for_outputs = np.arange(n_outputs) if self.image_for_outputs is None else self.image_for_outputs
        output_target_shapes = [None]*n_outputs if self.output_target_shapes is None else self.output_target_shapes

        self.gt_outputs = [K.placeholder(shape=_gt_shape(K.int_shape(x),sh)) for x,sh in zip(self.model.outputs,output_target_shapes)]

        input_slices  = (slice(None),) if self.input_slices  is None else self.input_slices
        output_slices = (slice(None),) if self.output_slices is None else self.output_slices
        if isinstance(input_slices[0],slice): # apply same slices to all inputs
            input_slices = [input_slices]*len(image_for_inputs)
        if isinstance(output_slices[0],slice): # apply same slices to all outputs
            output_slices = [output_slices]*len(image_for_outputs)
        len(input_slices)  == len(image_for_inputs)  or _raise(ValueError())
        len(output_slices) == len(image_for_outputs) or _raise(ValueError())

        def _name(prefix, layer, i, n, show_layer_names=False):
            # summary tag: prefix, index (only if there are several) and optionally the layer name
            return '{prefix}{i}{name}'.format (
                prefix = prefix,
                i      = (i if n > 1 else ''),
                name   = '' if (layer is None or not show_layer_names) else '_'+''.join(layer.name.split(':')[:-1]),
            )

        # inputs
        for i,sl in zip(image_for_inputs,input_slices):
            # print('input', self.model.inputs[i], tuple(sl))
            layer_name = _name('net_input', self.model.inputs[i], i, n_inputs)
            input_layer = tf_normalize_layer(self.model.inputs[i][tuple(sl)])
            tf_sums.append(tf.summary.image(layer_name, input_layer, max_outputs=self.n_images))

        # outputs
        for i,sl in zip(image_for_outputs,output_slices):
            # print('output', self.model.outputs[i], tuple(sl))
            output_shape = self.model.output_shape if n_outputs==1 else self.model.output_shape[i]
            # target
            output_layer = tf_normalize_layer(self.gt_outputs[i][tuple(sl)])
            layer_name = _name('net_target', self.model.outputs[i], i, n_outputs)
            tf_sums.append(tf.summary.image(layer_name, output_layer, max_outputs=self.n_images))
            # prediction
            n_channels_out = sep = output_shape[-1]
            if self.prob_out: # first half of output channels is mean, second half scale
                n_channels_out % 2 == 0 or _raise(ValueError())
                sep = sep // 2
            output_layer = tf_normalize_layer(self.model.outputs[i][...,:sep][tuple(sl)])
            if self.prob_out:
                scale_layer = tf_normalize_layer(self.model.outputs[i][...,sep:][tuple(sl)], pmin=0, pmax=100)
                mean_name  = _name('net_output_mean',  self.model.outputs[i], i, n_outputs)
                scale_name = _name('net_output_scale', self.model.outputs[i], i, n_outputs)
                tf_sums.append(tf.summary.image(mean_name,  output_layer, max_outputs=self.n_images))
                tf_sums.append(tf.summary.image(scale_name, scale_layer,  max_outputs=self.n_images))
            else:
                layer_name = _name('net_output', self.model.outputs[i], i, n_outputs)
                tf_sums.append(tf.summary.image(layer_name, output_layer, max_outputs=self.n_images))

        with tf.name_scope('merged'):
            self.merged = tf.summary.merge(tf_sums)

        with tf.name_scope('summary_writer'):
            if self.write_graph:
                self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
            else:
                self.writer = tf.summary.FileWriter(self.log_dir)

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate the merged summaries on validation data (every `freq` epochs) and write scalar logs."""
        logs = logs or {}

        if self.validation_data and self.freq:
            if epoch % self.freq == 0:
                # TODO: implement batched calls to sess.run
                # (current call will likely go OOM on GPU)

                tensors = self.model.inputs + self.gt_outputs + self.model.sample_weights

                if self.model.uses_learning_phase:
                    tensors += [K.learning_phase()]
                    # last entry of validation data is not sliced (it is fed to the learning phase tensor)
                    val_data = list(v[:self.n_images] for v in self.validation_data[:-1])
                    val_data += self.validation_data[-1:]
                else:
                    val_data = list(v[:self.n_images] for v in self.validation_data)

                feed_dict = dict(zip(tensors, val_data))
                result = self.sess.run([self.merged], feed_dict=feed_dict)
                summary_str = result[0]

                self.writer.add_summary(summary_str, epoch)

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = float(value)
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)

        self.writer.flush()

    def on_train_end(self, _):
        """Close the summary writer."""
        self.writer.close()


class CARETensorBoardImage(Callback):
    """Callback (TensorFlow 2 only) that writes normalized images of fixed example inputs,
    targets and the corresponding model predictions as TensorBoard image summaries.

    Parameters
    ----------
    model : BaseModel or keras.Model
        Model used for prediction (for a ``BaseModel``, its ``keras_model`` is used).
    data : tuple, list or indexable
        Either ``(X, Y)`` or an object whose first element ``data[0]`` is ``(X, Y)``.
    log_dir : str
        Directory for the summary file writer.
    n_images : int
        Number of examples (taken from the start of `X` and `Y`) to show.
    batch_size : int
        Batch size for ``model.predict``.
    prob_out : bool
        If true, the first half of the output channels is shown as mean, the second half as scale.
    image_for_inputs, image_for_outputs : iterable of int or None
        Indices of inputs/outputs to write images for (default: all).
    input_slices, output_slices : tuple of slice, or list of such tuples, or None
        Slices applied to the selected inputs/outputs before writing images.
    """
    def __init__(self, model, data, log_dir,
                 n_images=3,
                 batch_size=1,
                 prob_out=False,
                 image_for_inputs=None,      # write images for only these input indices
                 image_for_outputs=None,     # write images for only these output indices
                 input_slices=None,          # list (of list) of slices to apply to `image_for_inputs` layers before writing image
                 output_slices=None):        # list (of list) of slices to apply to `image_for_outputs` layers before writing image

        super(CARETensorBoardImage, self).__init__()
        backend_channels_last() or _raise(NotImplementedError())
        not IS_TF_1 or _raise(NotImplementedError("Not supported with TensorFlow 1"))

        ## model
        from ..models import BaseModel
        if isinstance(model, BaseModel):
            model = model.keras_model
        isinstance(model, keras.Model) or _raise(ValueError())
        # self.model is already used by keras.callbacks.Callback in keras 3+
        self.model_ = model
        self.n_inputs  = len(self.model_.inputs)
        self.n_outputs = len(self.model_.outputs)

        ## data
        if isinstance(data,(list,tuple)):
            X, Y = data
        else:
            X, Y = data[0]
        if self.n_inputs  == 1 and isinstance(X,np.ndarray): X = (X,)
        if self.n_outputs == 1 and isinstance(Y,np.ndarray): Y = (Y,)
        (self.n_inputs == len(X) and self.n_outputs == len(Y)) or _raise(ValueError())
        # all(len(v)>=n_images for v in X) or _raise(ValueError())
        # all(len(v)>=n_images for v in Y) or _raise(ValueError())
        self.X = [v[:n_images] for v in X]
        self.Y = [v[:n_images] for v in Y]

        self.log_dir = log_dir
        self.batch_size = batch_size
        self.prob_out = prob_out

        ## I/O choices
        image_for_inputs  = np.arange(self.n_inputs)  if image_for_inputs  is None else image_for_inputs
        image_for_outputs = np.arange(self.n_outputs) if image_for_outputs is None else image_for_outputs

        input_slices  = (slice(None),) if input_slices  is None else input_slices
        output_slices = (slice(None),) if output_slices is None else output_slices
        if isinstance(input_slices[0],slice): # apply same slices to all inputs
            input_slices = [input_slices]*len(image_for_inputs)
        if isinstance(output_slices[0],slice): # apply same slices to all outputs
            output_slices = [output_slices]*len(image_for_outputs)
        len(input_slices)  == len(image_for_inputs)  or _raise(ValueError())
        len(output_slices) == len(image_for_outputs) or _raise(ValueError())

        self.image_for_inputs = image_for_inputs
        self.image_for_outputs = image_for_outputs
        self.input_slices = input_slices
        self.output_slices = output_slices

        self.file_writer = tf.summary.create_file_writer(str(self.log_dir))

    def _name(self, prefix, layer, i, n, show_layer_names=False):
        """Summary tag: prefix, index (only if there are several) and optionally the layer name."""
        return '{prefix}{i}{name}'.format (
            prefix = prefix,
            i      = (i if n > 1 else ''),
            name   = '' if (layer is None or not show_layer_names) else '_'+''.join(layer.name.split(':')[:-1]),
        )

    def _normalize_image(self, x, pmin=1, pmax=99.8, clip=True):
        """NumPy counterpart of `tf_normalize_layer`: reduce `x` to a displayable array and normalize it."""
        def norm(x, axis):
            return normalize(x, pmin=pmin, pmax=pmax, axis=axis, clip=clip)

        n_channels_out = x.shape[-1]
        n_dim_out = len(x.shape)
        assert 0 < n_channels_out

        if n_dim_out > 4:
            x = np.max(x, axis=tuple(range(1,1+n_dim_out-4)))

        if n_channels_out == 1:
            out = norm(x, axis=(1,2))
        elif n_channels_out == 2:
            out = norm(np.concatenate([x,x[...,:1]], axis=-1), axis=(1,2,3))
        elif n_channels_out == 3:
            out = norm(x, axis=(1,2,3))
        else:
            out = norm(np.max(x, axis=-1, keepdims=True), axis=(1,2,3))
        return out

    def on_epoch_end(self, epoch, logs=None):
        """Write prediction images for this epoch; inputs and targets are written only at epoch 0."""

        # inputs
        if epoch == 0:
            with self.file_writer.as_default():
                for i,sl in zip(self.image_for_inputs,self.input_slices):
                    # print('input', self.model_.inputs[i], tuple(sl))
                    input_name = self._name('net_input', self.model_.inputs[i], i, self.n_inputs)
                    input_image = self._normalize_image(self.X[i][tuple(sl)])
                    tf.summary.image(input_name, input_image, step=epoch)

        # outputs
        Yhat = self.model_.predict(self.X, batch_size=self.batch_size, verbose=0)
        if self.n_outputs == 1 and isinstance(Yhat,np.ndarray): Yhat = (Yhat,)
        with self.file_writer.as_default():
            for i,sl in zip(self.image_for_outputs,self.output_slices):
                # print('output', self.model_.outputs[i], tuple(sl))
                output_shape = self.model_.output_shape if self.n_outputs==1 else self.model_.output_shape[i]
                n_channels_out = sep = output_shape[-1]
                if self.prob_out: # first half of output channels is mean, second half scale
                    n_channels_out % 2 == 0 or _raise(ValueError())
                    sep = sep // 2
                output_image = self._normalize_image(Yhat[i][...,:sep][tuple(sl)])
                if self.prob_out:
                    scale_image = self._normalize_image(Yhat[i][...,sep:][tuple(sl)], pmin=0, pmax=100)
                    scale_name = self._name('net_output_scale', self.model_.outputs[i], i, self.n_outputs)
                    output_name = self._name('net_output_mean', self.model_.outputs[i], i, self.n_outputs)
                    tf.summary.image(output_name, output_image, step=epoch)
                    tf.summary.image(scale_name, scale_image, step=epoch)
                else:
                    output_name = self._name('net_output', self.model_.outputs[i], i, self.n_outputs)
                    tf.summary.image(output_name, output_image, step=epoch)

        # targets
        if epoch == 0:
            with self.file_writer.as_default():
                for i,sl in zip(self.image_for_outputs,self.output_slices):
                    # print('target', self.model_.outputs[i], tuple(sl))
                    target_name = self._name('net_target', self.model_.outputs[i], i, self.n_outputs)
                    target_image = self._normalize_image(self.Y[i][tuple(sl)])
                    tf.summary.image(target_name, target_image, step=epoch)