# Source code for csbdeep.internals.nets

from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter

from ..utils.tf import keras_import
Input, Conv2D, Conv3D, Activation, Lambda, Add, Concatenate = keras_import('layers', 'Input', 'Conv2D', 'Conv3D', 'Activation', 'Lambda', 'Add', 'Concatenate')
Model = keras_import('models', 'Model')
from .blocks import unet_block
import re

from ..utils import _raise, backend_channels_last
import numpy as np


def custom_unet(input_shape,
                last_activation,
                n_depth=2,
                n_filter_base=16,
                kernel_size=(3,3,3),
                n_conv_per_depth=2,
                activation="relu",
                batch_norm=False,
                dropout=0.0,
                pool_size=(2,2,2),
                n_channel_out=1,
                residual=False,
                prob_out=False,
                eps_scale=1e-3):
    """Build a Keras U-Net model with a 1x1-convolution output head.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input image without the batch axis (passed to ``Input``).
        The channel axis is read as the last entry if ``backend_channels_last()``
        is true, otherwise as the first entry.
    last_activation : str
        Activation applied to the (optionally residual) output. Must not be None.
    n_depth, n_filter_base, n_conv_per_depth, activation, batch_norm, dropout
        Forwarded unchanged to ``unet_block``.
    kernel_size : tuple of int
        Convolution kernel size; every entry must be odd. A length of 2 selects
        ``Conv2D`` for the output head(s), any other length selects ``Conv3D``.
    pool_size : tuple of int
        Forwarded to ``unet_block`` as its ``pool`` argument.
    n_channel_out : int
        Number of channels produced by the final 1x1 convolution.
    residual : bool
        If True, the network input is added to the linear output before
        ``last_activation`` is applied; requires ``n_channel_out`` to equal the
        number of input channels.
    prob_out : bool
        If True, a second 1x1 convolution head with softplus activation (plus
        ``eps_scale``) is concatenated to the output along the channel axis.
    eps_scale : float
        Constant added to the softplus scale output.

    Returns
    -------
    Model
        Keras model mapping the ``input`` layer to the final output tensor.

    Raises
    ------
    ValueError
        If ``last_activation`` is None, or if ``residual`` is True and the input
        and output channel counts differ. For an even kernel size, a ValueError
        is handed to ``_raise`` (presumably re-raised there).
    """
    if last_activation is None:
        raise ValueError("last activation has to be given (e.g. 'sigmoid', 'relu')!")

    if not all(k % 2 == 1 for k in kernel_size):
        _raise(ValueError('kernel size should be odd in all dimensions.'))

    channel_axis = -1 if backend_channels_last() else 1

    n_dim = len(kernel_size)
    conv_layer = Conv2D if n_dim == 2 else Conv3D
    unit_kernel = (1,) * n_dim

    # NOTE: layers are created in exactly the original order so that Keras'
    # auto-generated layer names (and thus name-based weight loading) are unchanged.
    inp = Input(input_shape, name="input")
    features = unet_block(n_depth, n_filter_base, kernel_size,
                          activation=activation, dropout=dropout, batch_norm=batch_norm,
                          n_conv_per_depth=n_conv_per_depth, pool=pool_size)(inp)

    output = conv_layer(n_channel_out, unit_kernel, activation='linear')(features)
    if residual:
        n_channel_in = input_shape[-1] if backend_channels_last() else input_shape[0]
        if n_channel_out != n_channel_in:
            raise ValueError("number of input and output channels must be the same for a residual net.")
        output = Add()([output, inp])
    output = Activation(activation=last_activation)(output)

    if prob_out:
        # Positive per-pixel scale: softplus output shifted by eps_scale.
        scale = conv_layer(n_channel_out, unit_kernel, activation='softplus')(features)
        scale = Lambda(lambda x: x+np.float32(eps_scale))(scale)
        output = Concatenate(axis=channel_axis)([output, scale])

    return Model(inputs=inp, outputs=output)



def common_unet(n_dim=2, n_depth=1, kern_size=3, n_first=16, n_channel_out=1, residual=True, prob_out=False, last_activation='linear'):
    """Construct a common CARE neural net based on U-Net [1]_ and residual learning [2]_
    to be used for image restoration/enhancement.

    Parameters
    ----------
    n_dim : int
        number of image dimensions (2 or 3)
    n_depth : int
        number of resolution levels of U-Net architecture
    kern_size : int
        size of convolution filter in all image dimensions
    n_first : int
        number of convolution filters for first U-Net resolution level (value is doubled after each downsampling operation)
    n_channel_out : int
        number of channels of the predicted output image
    residual : bool
        if True, model will internally predict the residual w.r.t. the input (typically better)
        requires number of input and output image channels to be equal
    prob_out : bool
        standard regression (False) or probabilistic prediction (True)
        if True, model will predict two values for each input pixel (mean and positive scale value)
    last_activation : str
        name of activation function for the final output layer

    Returns
    -------
    function
        Function to construct the network, which takes as argument the shape of the input image

    Raises
    ------
    ValueError
        If ``n_dim`` is not 2 or 3.

    Example
    -------
    >>> model = common_unet(2, 1,3,16, 1, True, False)(input_shape)

    References
    ----------
    .. [1] Olaf Ronneberger, Philipp Fischer, Thomas Brox, *U-Net: Convolutional Networks for Biomedical Image Segmentation*, MICCAI 2015
    .. [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. *Deep Residual Learning for Image Recognition*, CVPR 2016
    """
    # Fail early: kernel and pool sizes are built as n_dim-tuples below, and only
    # 2D/3D convolutions are available, so any other n_dim could never build a model.
    if n_dim not in (2, 3):
        raise ValueError("n_dim must be 2 or 3, got %r" % (n_dim,))

    def _build_this(input_shape):
        return custom_unet(input_shape, last_activation, n_depth, n_first, (kern_size,)*n_dim,
                           pool_size=(2,)*n_dim, n_channel_out=n_channel_out,
                           residual=residual, prob_out=prob_out)

    return _build_this
modelname = re.compile(r"^(?P<model>resunet|unet)(?P<n_dim>\d)(?P<prob_out>p)?_(?P<n_depth>\d+)_(?P<kern_size>\d+)_(?P<n_first>\d+)(_(?P<n_channel_out>\d+)out)?(_(?P<last_activation>.+)-last)?$")


# The pattern in the docstring below is written as a reST inline literal (double
# backticks): no backslash-escape processing happens there, so the regex renders verbatim.
def common_unet_by_name(model):
    r"""Shorthand notation for equivalent use of :func:`common_unet`.

    Parameters
    ----------
    model : str
        define model to be created via string, which is parsed as a regular expression:
        ``^(?P<model>resunet|unet)(?P<n_dim>\d)(?P<prob_out>p)?_(?P<n_depth>\d+)_(?P<kern_size>\d+)_(?P<n_first>\d+)(_(?P<n_channel_out>\d+)out)?(_(?P<last_activation>.+)-last)?$``

    Returns
    -------
    function
        Calls :func:`common_unet` with the respective parameters.

    Raises
    ------
    ValueError
        If argument `model` is not a valid string according to the regular expression.

    Example
    -------
    >>> model = common_unet_by_name('resunet2_1_3_16_1out')(input_shape)
    >>> # equivalent to: model = common_unet(2, 1,3,16, 1, True, False)(input_shape)
    """
    m = modelname.fullmatch(model)
    if m is None:
        raise ValueError("model name '%s' unknown, must follow pattern '%s'" % (model, modelname.pattern))

    options = {k: int(m.group(k)) for k in ('n_depth', 'n_first', 'kern_size')}
    options['prob_out'] = m.group('prob_out') is not None
    # The regex only admits 'unet' or 'resunet' for this group.
    options['residual'] = {'unet': False, 'resunet': True}[m.group('model')]
    options['n_dim'] = int(m.group('n_dim'))
    options['n_channel_out'] = 1 if m.group('n_channel_out') is None else int(m.group('n_channel_out'))
    # Only override common_unet's default activation when the name specifies one.
    if m.group('last_activation') is not None:
        options['last_activation'] = m.group('last_activation')
    return common_unet(**options)


def receptive_field_unet(n_depth, kern_size, pool_size=2, n_dim=2, img_size=1024):
    """Empirically determine the receptive field of a U-Net (pre/post for each dimension).

    A freshly built :func:`custom_unet` with linear activations is applied to a
    single-channel image containing a unit impulse at its center, and to an all-zero
    image. Every output pixel whose value differs between the two runs is counted as
    inside the receptive field of the center pixel.

    Parameters
    ----------
    n_depth : int
        Number of U-Net resolution levels.
    kern_size : int
        Convolution kernel size in every spatial dimension.
    pool_size : int
        Pooling size in every spatial dimension.
    n_dim : int
        Number of spatial dimensions.
    img_size : int
        Side length of the (cubic) probe image; it must be large enough to contain the
        receptive field. Memory grows as ``img_size**n_dim``.

    Returns
    -------
    list of tuple
        One ``(pre, post)`` pair per spatial dimension: how far the receptive field
        extends before and after the center pixel.

    Raises
    ------
    ValueError
        If the impulse does not change any output pixel.
    """
    # NOTE(review): the probe is laid out as (batch, *spatial, channel), i.e. this
    # assumes a channels-last backend — confirm before using with channels-first.
    x = np.zeros((1,) + (img_size,)*n_dim + (1,))
    mid = tuple(s // 2 for s in x.shape[1:-1])
    x[(slice(None),) + mid + (slice(None),)] = 1

    model = custom_unet(
        x.shape[1:],
        n_depth=n_depth,
        kernel_size=[kern_size]*n_dim,
        pool_size=[pool_size]*n_dim,
        n_filter_base=8,
        activation='linear',
        last_activation='linear',
    )
    y = model.predict(x)[0, ..., 0]
    y0 = model.predict(np.zeros_like(x))[0, ..., 0]

    ind = np.where(np.abs(y - y0) > 0)
    if len(ind[0]) == 0:
        raise ValueError("impulse response is zero everywhere; cannot determine receptive field.")
    return [(m - np.min(i), np.max(i) - m) for (m, i) in zip(mid, ind)]