Source code for csbdeep.models.care_projection

# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
from collections import namedtuple

from ..utils.tf import keras_import, BACKEND as K
Model = keras_import('models', 'Model')
Input, Conv3D, MaxPooling3D, UpSampling3D, Lambda, Multiply = keras_import('layers', 'Input', 'Conv3D', 'MaxPooling3D', 'UpSampling3D', 'Lambda', 'Multiply')
softmax = keras_import('activations', 'softmax')

from .care_standard import CARE
from .config import Config
from ..utils import _raise, axes_dict, axes_check_and_normalize
from ..internals import nets
from ..internals.predict import tile_overlap


class ProjectionConfig(Config):
    """Configuration that adds ``proj_*`` defaults on top of the base ``Config``.

    Parameters
    ----------
    axes, n_channel_in, n_channel_out, probabilistic
        Passed positionally, unchanged, to ``Config.__init__``.
    allow_new_parameters
        Forwarded as first argument to ``update_parameters``.
    **kwargs
        ``proj_axis`` (default ``'Z'``) is read from here to derive the
        default kernel and pooling sizes; all keyword arguments are then
        forwarded to ``update_parameters``.

    Attributes set here (before ``update_parameters`` is called)
    ------------------------------------------------------------
    proj_axis : value of ``kwargs['proj_axis']`` or ``'Z'``
    proj_n_depth : 4
    proj_n_filt : 8
    proj_n_conv_per_depth : 1
    proj_kern : 3-tuple, 3 for every dimension
    proj_pool : 3-tuple, 1 for the projection dimension and 2 otherwise
    """

    def __init__(self, axes='ZYX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs):
        super(ProjectionConfig, self).__init__(axes, n_channel_in, n_channel_out, probabilistic)
        axis_index = axes_dict(self.axes)

        self.proj_axis = kwargs.get('proj_axis', 'Z')
        self.proj_n_depth = 4
        self.proj_n_filt = 8
        self.proj_n_conv_per_depth = 1

        # Build per-dimension kernel and pooling sizes in a single pass.
        kern_sizes, pool_sizes = [], []
        for d in range(3):
            along_proj = (d == axis_index[self.proj_axis])
            # kernel size is currently 3 both along and across the projection axis
            kern_sizes.append(3 if along_proj else 3)
            # no pooling (factor 1) along the projection axis, factor 2 elsewhere
            pool_sizes.append(1 if along_proj else 2)
        self.proj_kern = tuple(kern_sizes)
        self.proj_pool = tuple(pool_sizes)

        self.update_parameters(allow_new_parameters, **kwargs)



class ProjectionCARE(CARE):
    """CARE network for combined image restoration and projection of one dimension."""

    @property
    def proj_params(self):
        """Validated projection parameters as a ``ProjectionParameters`` namedtuple.

        The tuple has the fields ``axis``, ``n_depth``, ``n_filt``,
        ``n_conv_per_depth``, ``kern`` and ``pool``. Values are read from the
        ``proj_*`` attributes of ``self.config`` (with defaults if missing),
        checked, and cached in ``self._proj_params`` on first access.

        Raises
        ------
        ValueError
            If the model axes do not have length 4, the projection axis is not
            part of the model axes, the last model axis is not ``'C'``, or any
            of the numeric parameters is invalid.
        """
        assert self.config is not None
        try:
            return self._proj_params
        except AttributeError:
            # TODO: no need to be so cautious here, since there's now a dedicated ProjectionConfig class
            cfg = vars(self.config)
            p = {}
            p['axis'] = cfg.get('proj_axis', 'Z')
            p['n_depth'] = int(cfg.get('proj_n_depth', 4))
            p['n_filt'] = int(cfg.get('proj_n_filt', 8))
            p['n_conv_per_depth'] = int(cfg.get('proj_n_conv_per_depth', 1))
            p['axis'] = axes_check_and_normalize(p['axis'], length=1)

            ax = axes_dict(self.config.axes)
            len(self.config.axes) == 4 or _raise(ValueError("model must take 3D input, but axes are {self.config.axes}.".format(self=self)))
            ax[p['axis']] is not None or _raise(ValueError("projection axis {axis} not part of model axes {self.config.axes}".format(self=self, axis=p['axis'])))
            self.config.axes[-1] == 'C' or _raise(ValueError(
                "last model axis must be 'C', but axes are {self.config.axes}.".format(self=self)))
            (p['n_depth'] > 0 and p['n_filt'] > 0 and p['n_conv_per_depth'] > 0) or _raise(ValueError(
                "proj_n_depth, proj_n_filt and proj_n_conv_per_depth must all be positive."))

            # defaults: kernel size 3 everywhere; pooling 1 along the projection axis, 2 elsewhere
            p['kern'] = tuple(cfg.get('proj_kern', (3 if d == ax[p['axis']] else 3 for d in range(3))))
            p['pool'] = tuple(cfg.get('proj_pool', (1 if d == ax[p['axis']] else 2 for d in range(3))))
            3 == len(p['pool']) == len(p['kern']) or _raise(ValueError(
                "proj_kern and proj_pool must each have exactly 3 entries."))
            all(isinstance(v, int) and v > 0 for v in p['kern']) or _raise(ValueError(
                "proj_kern must contain only positive integers, but is {kern}.".format(kern=p['kern'])))
            all(isinstance(v, int) and v > 0 for v in p['pool']) or _raise(ValueError(
                "proj_pool must contain only positive integers, but is {pool}.".format(pool=p['pool'])))

            self._proj_params = namedtuple('ProjectionParameters', p.keys())(*p.values())
            return self._proj_params

    def _repr_extra(self):
        """Return an extra line containing the projection parameters."""
        return "├─ {self.proj_params}\n".format(self=self)

    def _update_and_check_config(self):
        """Write the validated projection parameters back to the config as ``proj_*`` attributes."""
        assert self.config is not None
        for k, v in self.proj_params._asdict().items():
            setattr(self.config, 'proj_' + k, v)

    def _build(self):
        """Build the model: a 3D projection network chained with a denoising network.

        Returns a ``Model`` that maps the 3D input to the output of the
        denoising network applied to the projected input.
        """
        # get parameters
        proj = self.proj_params
        proj_axis = axes_dict(self.config.axes)[proj.axis]

        # define surface projection network (3D -> 2D)
        inp = u = Input(self.config.unet_input_shape)

        def conv_layers(u):
            for _ in range(proj.n_conv_per_depth):
                u = Conv3D(proj.n_filt, proj.kern, padding='same', activation='relu')(u)
            return u

        # down
        for _ in range(proj.n_depth):
            u = conv_layers(u)
            u = MaxPooling3D(proj.pool)(u)
        # middle
        u = conv_layers(u)
        # up
        for _ in range(proj.n_depth):
            u = UpSampling3D(proj.pool)(u)
            u = conv_layers(u)
        u = Conv3D(1, proj.kern, padding='same', activation='linear')(u)
        # convert learned features along Z to surface probabilities
        # (add 1 to proj_axis because of batch dimension in tensorflow)
        u = Lambda(lambda x: softmax(x, axis=1+proj_axis))(u)
        # multiply Z probabilities with Z values in input stack
        u = Multiply()([inp, u])
        # perform surface projection by summing over weighted Z values
        u = Lambda(lambda x: K.sum(x, axis=1+proj_axis))(u)
        model_projection = Model(inp, u)

        # define denoising network (2D -> 2D)
        # (remove projected axis from input_shape)
        input_shape = list(self.config.unet_input_shape)
        del input_shape[proj_axis]
        model_denoising = nets.common_unet(
            n_dim=self.config.n_dim-1,
            n_channel_out=self.config.n_channel_out,
            prob_out=self.config.probabilistic,
            residual=self.config.unet_residual,
            n_depth=self.config.unet_n_depth,
            kern_size=self.config.unet_kern_size,
            n_first=self.config.unet_n_first,
            last_activation=self.config.unet_last_activation,
        )(tuple(input_shape))

        # chain models together
        return Model(inp, model_denoising(model_projection(inp)))

    def train(self, X, Y, validation_data, **kwargs):
        """Drop the singleton projection axis from the targets, then call the parent ``train``.

        ``Y`` must have size 1 along the projection axis; that axis is removed
        before training. The same removal is attempted (best effort) for the
        target part of ``validation_data`` if it can be unpacked as a pair.

        Raises
        ------
        ValueError
            If ``Y`` does not have size 1 along the projection axis.
        """
        proj_axis = self.proj_params.axis
        # offset by 1, presumably for the leading sample dimension of X and Y
        proj_axis = 1 + axes_dict(self.config.axes)[proj_axis]
        Y.shape[proj_axis] == 1 or _raise(ValueError(
            "target images must have size 1 along axis {axis}, but have shape {shape}.".format(axis=proj_axis, shape=Y.shape)))
        Y = np.take(Y, 0, axis=proj_axis)
        try:
            X_val, Y_val = validation_data
            # Y_val.shape[proj_axis] == 1 or _raise(ValueError())
            validation_data = X_val, np.take(Y_val, 0, axis=proj_axis)
        except Exception:
            # Best effort: if validation_data cannot be unpacked or reduced, it is
            # passed on unchanged. Only Exception is caught, so KeyboardInterrupt
            # and SystemExit are no longer swallowed here.
            pass
        return super(ProjectionCARE, self).train(X, Y, validation_data, **kwargs)

    def _axes_div_by(self, query_axes):
        """Return, for each queried axis, the factor its size must be divisible by.

        For every non-channel model axis this is the larger of the projection
        network's total pooling factor and ``2**unet_n_depth`` (1 instead for
        the projection axis). Axes not part of the model get 1.
        """
        query_axes = axes_check_and_normalize(query_axes)
        proj = self.proj_params
        div_by = {
            a: max(a_proj_pool**proj.n_depth, 1 if a == proj.axis else 2**self.config.unet_n_depth)
            for a, a_proj_pool in zip(self.config.axes.replace('C', ''), proj.pool)
        }
        return tuple(div_by.get(a, 1) for a in query_axes)

    def _axes_tile_overlap(self, query_axes):
        """Return the (approximate) tile overlap for each queried axis.

        For every non-channel, non-projection model axis this is the larger of
        the projection network's and the denoising network's ``tile_overlap``.
        All other axes get 0.
        """
        query_axes = axes_check_and_normalize(query_axes)
        proj = self.proj_params
        unet_overlap = tile_overlap(self.config.unet_n_depth, self.config.unet_kern_size)
        overlap = {
            a: max(tile_overlap(proj.n_depth, a_proj_kern, a_proj_pool), unet_overlap)  # approx
            for a, a_proj_pool, a_proj_kern in zip(self.config.axes.replace('C', ''), proj.pool, proj.kern)
            if a != proj.axis
        }
        return tuple(overlap.get(a, 0) for a in query_axes)

    @property
    def _axes_out(self):
        """Model axes with the projection axis removed."""
        return ''.join(a for a in self.config.axes if a != self.proj_params.axis)

    @property
    def _config_class(self):
        """Configuration class associated with this model."""
        return ProjectionConfig