from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
from scipy.ndimage import zoom
from csbdeep.internals.probability import ProbabilisticPrediction
from .care_standard import CARE
from ..internals.predict import predict_direct
from ..data import PercentileNormalizer, PadAndCropResizer
from ..utils import _raise, axes_check_and_normalize
class IsotropicCARE(CARE):
    """CARE network for isotropic image reconstruction.

    Extends :class:`csbdeep.models.CARE` by replacing prediction
    (:func:`predict`, :func:`predict_probabilistic`) to do isotropic reconstruction.
    """

    def predict(self, img, axes, factor, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), batch_size=8):
        """Apply neural network to raw image for isotropic reconstruction.

        See :func:`CARE.predict` for documentation.

        Parameters
        ----------
        factor : float
            Upsampling factor for Z axis. It is important that this is chosen in correspondence
            to the subsampling factor used during training data generation.
        batch_size : int
            Number of image slices that are processed together by the neural network.
            Reduce this value if out of memory errors occur.

        Returns
        -------
        :class:`numpy.ndarray`
            The restored image (first element of :func:`_predict_mean_and_scale`).
        """
        return self._predict_mean_and_scale(img, axes, factor, normalizer, resizer, batch_size)[0]

    def predict_probabilistic(self, img, axes, factor, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), batch_size=8):
        """Apply neural network to raw image to predict probability distribution for isotropic restored image.

        See :func:`CARE.predict_probabilistic` for documentation.

        Parameters
        ----------
        factor : float
            Upsampling factor for Z axis. It is important that this is chosen in correspondence
            to the subsampling factor used during training data generation.
        batch_size : int
            Number of image slices that are processed together by the neural network.
            Reduce this value if out of memory errors occur.

        Returns
        -------
        :class:`ProbabilisticPrediction`
            Built from the ``(mean, scale)`` pair of :func:`_predict_mean_and_scale`.

        Raises
        ------
        ValueError
            If the model was not configured as probabilistic.
        """
        self.config.probabilistic or _raise(ValueError('This is not a probabilistic model.'))
        mean, scale = self._predict_mean_and_scale(img, axes, factor, normalizer, resizer, batch_size)
        return ProbabilisticPrediction(mean, scale)

    def _predict_mean_and_scale(self, img, axes, factor, normalizer, resizer, batch_size):
        """Apply neural network to raw image to restore isotropic resolution.

        See :func:`predict` for parameter explanations.

        Procedure: the image is permuted to a channel-first ``'CZ...'`` layout,
        normalized, upsampled along Z by ``factor`` (linear interpolation, ``order=1``)
        and padded by ``resizer``. The network is then applied to two differently
        rotated views of the volume. Both predictions are rotated back, un-padded and
        fused: the means by a geometric mean (negative values clipped to 0), the
        scales (probabilistic models only) by their element-wise maximum.

        Returns
        -------
        tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray` or None)
            If model is probabilistic, returns a tuple `(mean, scale)` that defines the parameters
            of per-pixel Laplace distributions. Otherwise, returns the restored image via a tuple `(restored,None)`

        Raises
        ------
        ValueError
            If ``axes`` has no Z axis, the image's channel count does not match the
            model's ``n_channel_in``, ``factor`` is not a positive scalar, or the
            network output has an unexpected number of channels.
        """
        normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)
        axes = axes_check_and_normalize(axes, img.ndim)
        'Z' in axes or _raise(ValueError(
            "axes '{}' must contain a Z axis for isotropic reconstruction.".format(axes)))

        # internal layout: channel first, Z second, then the remaining axes
        axes_tmp = 'CZ' + axes.replace('Z', '').replace('C', '')
        _permute_axes = self._make_permute_axes(axes, axes_tmp)
        x = _permute_axes(img)

        self.config.n_channel_in == x.shape[0] or _raise(ValueError(
            "model expects {} input channel(s), but image has {}.".format(self.config.n_channel_in, x.shape[0])))
        np.isscalar(factor) and factor > 0 or _raise(ValueError(
            "factor must be a positive scalar, got {!r}.".format(factor)))

        # normalize
        x = normalizer.before(x, axes_tmp)
        # scale z up (second axis of the channel-first layout) with linear interpolation
        x_scaled = zoom(x, (1, factor, 1, 1), order=1)
        # resize: make (x,y,z) image dimensions divisible by power of 2 to allow downsampling steps in unet
        x_scaled = resizer.before(x_scaled, axes_tmp, self._axes_div_by(axes_tmp))

        # move channel to the end (axes_predict semantics)
        x_scaled = np.moveaxis(x_scaled, 0, -1)
        axes_predict = 'S' + axes_tmp[2:] + 'C'

        # probabilistic models emit mean and scale channels for every output channel
        n_channel_predicted = self.config.n_channel_out * (2 if self.config.probabilistic else 1)

        def _check_n_channels(u):
            # validate right after each prediction, so a bad model fails before the
            # second (expensive) prediction is run
            u.shape[-1] == n_channel_predicted or _raise(ValueError(
                "network predicted {} channel(s), expected {}.".format(u.shape[-1], n_channel_predicted)))

        # u1: first rotation and prediction
        x_rot1 = self._rotate(x_scaled, axis=1, copy=False)
        u_rot1 = predict_direct(self.keras_model, x_rot1, axes_predict, batch_size=batch_size, verbose=0)
        _check_n_channels(u_rot1)
        u1 = self._rotate(u_rot1, -1, axis=1, copy=False)

        # u2: second rotation and prediction
        x_rot2 = self._rotate(self._rotate(x_scaled, axis=2, copy=False), axis=0, copy=False)
        u_rot2 = predict_direct(self.keras_model, x_rot2, axes_predict, batch_size=batch_size, verbose=0)
        _check_n_channels(u_rot2)
        u2 = self._rotate(self._rotate(u_rot2, -1, axis=0, copy=False), -1, axis=2, copy=False)

        # move channel back to the front (axes_tmp semantics), then undo the resizing
        u1 = resizer.after(np.moveaxis(u1, -1, 0), axes_tmp)
        u2 = resizer.after(np.moveaxis(u2, -1, 0), axes_tmp)

        # combine u1 & u2
        mean1, scale1 = self._mean_and_scale_from_prediction(u1, axis=0)
        mean2, scale2 = self._mean_and_scale_from_prediction(u2, axis=0)
        # fuse both views with a geometric mean (negative values clipped to 0);
        # an arithmetic mean (mean1 + mean2) / 2 would be the alternative
        mean = np.sqrt(np.maximum(mean1, 0) * np.maximum(mean2, 0))
        scale = np.maximum(scale1, scale2) if self.config.probabilistic else None

        if normalizer.do_after and self.config.n_channel_in == self.config.n_channel_out:
            mean, scale = normalizer.after(mean, scale, axes_tmp)

        # undo the axis permutation that was applied to the input image
        mean, scale = _permute_axes(mean, undo=True), _permute_axes(scale, undo=True)
        return mean, scale

    @staticmethod
    def _rotate(arr, k=1, axis=1, copy=True):
        """Rotate ``arr`` by ``k`` multiples of 90 degrees while keeping ``axis`` fixed.

        ``axis`` is temporarily moved to the end, the array is rotated in the plane
        of its (then) first two axes, and ``axis`` is moved back into place.
        ``k`` is taken modulo 4, so ``k=-1`` undoes ``k=1``. If ``copy`` is False,
        the result may be a view sharing memory with ``arr``.
        """
        if copy:
            arr = arr.copy()
        arr = np.moveaxis(arr, axis, -1)
        # np.rot90 rotates from the first towards the second axis for positive k;
        # this method rotates the other way (k=1 equals arr[::-1].swapaxes(0, 1)),
        # hence the negated k
        res = np.rot90(arr, -k, axes=(0, 1))
        return np.moveaxis(res, -1, axis)