Model training

Given suitable training data (see Training data generation), we can define and train a CARE model to restore the source data. To that end, we first specify all options of the model by creating a configuration object via csbdeep.models.Config. We provide sensible default configuration options that should work in many cases; however, you can override them via keyword arguments.

Please see Model overview to choose among the supported restoration models. While training data generation and prediction typically differ among the models, the training process is mostly the same for all of them. For example, a standard (denoising) CARE model can be instantiated via csbdeep.models.CARE and then trained with the csbdeep.models.CARE.train() method. After training, the learned model can be exported via csbdeep.models.CARE.export_TF() to be used with our Fiji Plugin.


>>> from csbdeep.io import load_training_data
>>> from csbdeep.models import Config, CARE
>>> (X,Y), (X_val,Y_val), axes = load_training_data('my_data.npz', validation_split=0.1)
>>> config = Config(axes)
>>> model = CARE(config, 'my_model')
>>> model.train(X,Y, validation_data=(X_val,Y_val))
>>> model.export_TF()
class csbdeep.models.Config(axes='YX', n_channel_in=1, n_channel_out=1, probabilistic=False, allow_new_parameters=False, **kwargs)

Default configuration for a CARE model.

This configuration is meant to be used with CARE and related models (e.g., IsotropicCARE).

  • axes (str) – Axes of the neural network (channel axis optional).
  • n_channel_in (int) – Number of channels of given input image.
  • n_channel_out (int) – Number of channels of predicted output image.
  • probabilistic (bool) – Probabilistic prediction of per-pixel Laplace distributions or typical regression of per-pixel scalar values.
  • allow_new_parameters (bool) – Allow adding new configuration attributes (i.e. not listed below).
  • kwargs (dict) – Overwrite (or add) configuration attributes (see below).
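To make the two modes selected by `probabilistic` concrete, the sketch below compares the loss functions involved, assuming their textbook definitions: mean absolute error (MAE) for per-pixel scalar regression, and the negative log-likelihood of a Laplace distribution with mean `mu` and scale `b` for probabilistic prediction. The function names are hypothetical, and whether csbdeep's `'laplace'` training loss uses exactly this form (including the constant `log(2)` term) is an assumption.

```python
import math


def mae_loss(y_true, y_pred):
    """Mean absolute error over pixels (regression of per-pixel scalar values)."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def laplace_nll(y_true, mu, b):
    """Mean negative log-likelihood of per-pixel Laplace(mu, b) distributions.

    The Laplace density is p(y) = exp(-|y - mu| / b) / (2 b), so the
    negative log-likelihood of one pixel is log(2 b) + |y - mu| / b.
    """
    total = 0.0
    for t, m, s in zip(y_true, mu, b):
        if s <= 0:
            raise ValueError("Laplace scale must be positive")
        total += math.log(2 * s) + abs(t - m) / s
    return total / len(y_true)
```

With a fixed scale `b`, the Laplace loss is simply MAE divided by `b` plus a constant. A probabilistic model additionally learns `b` per pixel, which is how it expresses its own uncertainty.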


>>> config = Config('YX', probabilistic=True, unet_n_depth=3)

int – Dimensionality of input images (2 or 3).


bool – Parameter residual of csbdeep.internals.nets.common_unet(). Default: n_channel_in == n_channel_out


int – Parameter n_depth of csbdeep.internals.nets.common_unet(). Default: 2


int – Parameter kern_size of csbdeep.internals.nets.common_unet(). Default: 5 if n_dim==2 else 3


int – Parameter n_first of csbdeep.internals.nets.common_unet(). Default: 32


str – Parameter last_activation of csbdeep.internals.nets.common_unet(). Default: linear


str – Name of training loss. Default: 'laplace' if probabilistic else 'mae'


int – Number of training epochs. Default: 100


int – Number of parameter update steps per epoch. Default: 400


float – Learning rate for training. Default: 0.0004


int – Batch size for training. Default: 16


bool – Enable TensorBoard for monitoring training progress. Default: True


str – Name of checkpoint file for model weights (only best are saved); set to None to disable. Default: weights_best.h5


dict – Parameter dict of ReduceLROnPlateau callback; set to None to disable. Default: {'factor': 0.5, 'patience': 10, 'min_delta': 0}
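The interplay between defaults and keyword overrides can be sketched with a small stand-in class. `SketchConfig` is hypothetical and is not the library's `Config`. It reproduces only some of the defaults listed above, and all attribute names except `unet_n_depth` (which appears in the example call) are assumptions. It also assumes that `n_dim` is derived by counting the spatial axes Z, Y and X in `axes`.

```python
class SketchConfig:
    """Hypothetical, simplified stand-in for csbdeep.models.Config.

    Attribute names other than unet_n_depth are assumptions for illustration.
    """

    def __init__(self, axes="YX", n_channel_in=1, n_channel_out=1,
                 probabilistic=False, allow_new_parameters=False, **kwargs):
        self.axes = axes
        # Assumption: n_dim counts the spatial axes; a channel axis is ignored.
        self.n_dim = sum(ax in "ZYX" for ax in axes.upper())
        self.n_channel_in = n_channel_in
        self.n_channel_out = n_channel_out
        self.probabilistic = probabilistic
        # Defaults that depend on other settings, as documented above.
        self.unet_residual = n_channel_in == n_channel_out
        self.unet_n_depth = 2
        self.unet_kern_size = 5 if self.n_dim == 2 else 3
        self.unet_n_first = 32
        self.train_loss = "laplace" if probabilistic else "mae"
        self.train_epochs = 100
        # Keyword arguments override defaults; unknown names are rejected
        # unless allow_new_parameters=True.
        for key, value in kwargs.items():
            if not allow_new_parameters and not hasattr(self, key):
                raise AttributeError(f"unknown configuration attribute: {key}")
            setattr(self, key, value)
```

For example, `SketchConfig('YX', probabilistic=True, unet_n_depth=3)` mirrors the documented call: it yields `n_dim == 2`, a kernel size of 5, the `'laplace'` loss and a U-Net depth of 3. By contrast, `SketchConfig('YX', foo=1)` raises `AttributeError` because `foo` is not a known attribute.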


Check if configuration is valid.

Returns:Flag that indicates whether the current configuration values are valid.
Return type:bool
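A validity check like this can be viewed as a conjunction of simple rules. The sketch below is hypothetical and operates on a plain dict. It encodes constraints stated elsewhere in this section (2D or 3D images, and residual learning requires equal numbers of input and output channels), together with the assumption that channel counts, epochs and batch size must be positive integers. The rules actually applied by the library may differ.

```python
def check_config(cfg):
    """Return True only if all (hypothetical) validity rules hold for `cfg`."""

    def positive_int(key):
        value = cfg.get(key)
        # bool is a subclass of int in Python, so exclude it explicitly
        return isinstance(value, int) and not isinstance(value, bool) and value > 0

    rules = [
        cfg.get("n_dim") in (2, 3),        # only 2D or 3D images
        positive_int("n_channel_in"),
        positive_int("n_channel_out"),
        positive_int("train_epochs"),      # assumption
        positive_int("train_batch_size"),  # assumption
        # residual learning requires equal input and output channel counts
        not cfg.get("unet_residual", False)
        or cfg.get("n_channel_in") == cfg.get("n_channel_out"),
    ]
    return all(rules)
```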
class csbdeep.models.CARE(config, name=None, basedir='.')

Standard CARE network for image restoration and enhancement.

Uses a convolutional neural network created by csbdeep.internals.nets.common_unet(). Note that isotropic reconstruction and manifold extraction/projection are not supported here (see csbdeep.models.IsotropicCARE).

  • config (csbdeep.models.Config or None) – Valid configuration of CARE network (see Config.is_valid()). Will be saved to disk as JSON (config.json). If set to None, will be loaded from disk (must exist).
  • name (str or None) – Model name. Uses a timestamp if set to None (default).
  • basedir (str) – Directory that contains (or will contain) a folder with the given model name. Use None to disable saving (or loading) any data to (or from) disk (regardless of other parameters).

Raises:
  • FileNotFoundError – If config=None and config cannot be loaded from disk.
  • ValueError – Illegal arguments, including invalid configuration.


>>> model = CARE(config, 'my_model')

csbdeep.models.Config – Configuration of CARE network, as provided during instantiation.


Keras model – Keras neural network model.


str – Model name.


pathlib.Path – Path to model folder (which stores configuration, weights, etc.)
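The effect of `name` and `basedir` on where a model is stored can be sketched as follows. `resolve_model_dir` is a hypothetical helper, and the timestamp format used for unnamed models is an assumption. Only the overall behaviour follows the parameter descriptions: a timestamp when `name=None`, the model folder `basedir/name`, and no disk access when `basedir=None`.

```python
from datetime import datetime
from pathlib import Path


def resolve_model_dir(name=None, basedir="."):
    """Hypothetical helper mirroring the documented name/basedir semantics.

    Returns (model_name, model_folder); the folder is None if basedir is None,
    i.e. when nothing should be saved to or loaded from disk.
    """
    if name is None:
        # Assumption: the exact timestamp format is chosen for illustration only.
        name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S.%f")
    if basedir is None:
        return name, None
    return name, Path(basedir) / name
```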


export_TF(fname=None)

Export the neural network as a TensorFlow SavedModel archive for use with the CSBDeep Fiji plugin.

Parameters:fname (str or None) – Path of the created SavedModel archive (will end with “.zip”). If None, “<model-directory>/” will be used.
predict(img, axes, normalizer=< object>, resizer=< object>, n_tiles=None)

Apply neural network to raw image to predict restored image.

  • img (numpy.ndarray) – Raw input image
  • axes (str) – Axes of the input img.
  • normalizer ( or None) – Normalization of input image before prediction and (potentially) transformation back after prediction.
  • resizer ( or None) – If necessary, input image is resized to enable neural network prediction and result is (possibly) resized to yield original image size.
  • n_tiles (iterable or None) – Out of memory (OOM) errors can occur if the input image is too large. To avoid this problem, the input image is broken up into (overlapping) tiles that can then be processed independently and re-assembled to yield the restored image. This parameter denotes a tuple of the number of tiles for every image axis. Note that if the number of tiles is too low, it is adaptively increased until OOM errors are avoided, albeit at the expense of runtime. A value of None denotes that no tiling should initially be used.
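The tiling strategy described for `n_tiles` can be illustrated in one dimension. This pure-Python sketch is not csbdeep's implementation. Each tile is padded by a fixed `overlap` so that pixels near tile borders still see their neighbours, only the unpadded core of each processed tile is kept when re-assembling, and the number of tiles is doubled after every `MemoryError`. The function names, the fixed overlap and the doubling rule are all assumptions.

```python
def tile_bounds(length, n_tiles, overlap):
    """Split range(length) into n_tiles cores, each padded by `overlap`.

    Returns (start, stop, core_start, core_stop) per tile: the padded slice
    that is processed and the core slice that is kept when re-assembling.
    """
    bounds = []
    for i in range(n_tiles):
        core_start = i * length // n_tiles
        core_stop = (i + 1) * length // n_tiles
        start = max(0, core_start - overlap)
        stop = min(length, core_stop + overlap)
        bounds.append((start, stop, core_start, core_stop))
    return bounds


def predict_tiled(signal, process, n_tiles, overlap):
    """Apply `process` tile by tile and stitch the results back together."""
    out = []
    for start, stop, core_start, core_stop in tile_bounds(len(signal), n_tiles, overlap):
        result = process(signal[start:stop])
        # keep only the core; the overlap just provides context at tile borders
        out.extend(result[core_start - start:core_stop - start])
    return out


def predict_adaptive(signal, process, n_tiles=1, overlap=2, max_tiles=64):
    """Retry with twice as many tiles whenever `process` raises MemoryError."""
    while n_tiles <= max_tiles:
        try:
            return predict_tiled(signal, process, n_tiles, overlap), n_tiles
        except MemoryError:
            n_tiles *= 2
    raise MemoryError("still out of memory with max_tiles tiles")
```

Consider a `process` that only needs one neighbour on each side, such as a 3-point moving average. With `overlap=1`, the stitched result is identical to processing the whole signal at once. Operations with a larger receptive field, such as a U-Net, need a correspondingly larger overlap.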

Returns the restored image. If the model is probabilistic, this is the mean parameter of the predicted per-pixel Laplace distributions (i.e., the expected restored image). Axis semantics are the same as in the input image, with one exception: if the output is multi-channel and the input image has no channel axis, the output channels are appended as the last axis.
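The axis rule for the returned image can be written as a single decision. `output_axes` is a hypothetical helper, and it assumes that the channel axis is denoted by the letter `C`.

```python
def output_axes(input_axes, n_channel_out):
    """Axes of the restored image for a given input axes string.

    Assumption: 'C' marks the channel axis. Axis semantics are kept; a channel
    axis is appended only for multi-channel output when the input has none.
    """
    if n_channel_out > 1 and "C" not in input_axes:
        return input_axes + "C"
    return input_axes
```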

Return type:


predict_probabilistic(img, axes, normalizer=< object>, resizer=< object>, n_tiles=None)

Apply neural network to raw image to predict probability distribution for restored image.

See predict() for parameter explanations.

Returns:Returns the probability distribution of the restored image.
Return type:csbdeep.internals.probability.ProbabilisticPrediction
Raises:ValueError – If this is not a probabilistic model.
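The methods of `ProbabilisticPrediction` are not described in this section. Independently of that API, per-pixel Laplace parameters can be turned into useful quantities with textbook formulas. The sketch below assumes a mean `mu` and a positive scale `b` for a single pixel, and its function names are hypothetical. It computes the variance `2 * b**2`, quantiles from the inverse CDF, and samples via inverse-transform sampling.

```python
import math


def laplace_summary(mu, b):
    """Summary statistics of a Laplace(mu, b) distribution for one pixel."""
    if b <= 0:
        raise ValueError("scale b must be positive")
    return {"mean": mu, "median": mu, "variance": 2 * b * b, "std": math.sqrt(2) * b}


def laplace_quantile(mu, b, q):
    """Inverse CDF of Laplace(mu, b) for a probability q in (0, 1)."""
    if not 0 < q < 1:
        raise ValueError("q must lie in (0, 1)")
    if q < 0.5:
        return mu + b * math.log(2 * q)
    return mu - b * math.log(2 * (1 - q))


def laplace_sample(mu, b, rng):
    """Draw one sample via inverse-transform sampling (rng: random.Random)."""
    u = rng.random()
    while u == 0.0:  # q = 0 has no finite quantile
        u = rng.random()
    return laplace_quantile(mu, b, u)
```

For example, `laplace_quantile(mu, b, 0.05)` and `laplace_quantile(mu, b, 0.95)` bound a central 90% interval around the predicted mean.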
prepare_for_training(optimizer=None, **kwargs)

Prepare for neural network training.

Calls csbdeep.internals.train.prepare_model() and creates Keras Callbacks to be used for training.

Note that this method will be implicitly called once by train() (with default arguments) if not done so explicitly beforehand.

train(X, Y, validation_data, epochs=None, steps_per_epoch=None)

Train the neural network with the given data.

  • X (numpy.ndarray) – Array of source images.
  • Y (numpy.ndarray) – Array of target images.
  • validation_data (tuple(numpy.ndarray, numpy.ndarray)) – Tuple of arrays for source and target validation images.
  • epochs (int) – Optional argument to use instead of the value from config.
  • steps_per_epoch (int) – Optional argument to use instead of the value from config.

Returns:See Keras training history.
Return type:History object
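This section documents two behaviours of `train()`. Explicit `epochs` and `steps_per_epoch` arguments take precedence over the configuration, and `prepare_for_training()` is called implicitly once if it has not been called beforehand. Both can be sketched with a stand-in class. `SketchTrainer` is hypothetical and performs no actual training. The configuration keys `train_epochs` and `train_steps_per_epoch` are assumed names for the documented epoch and steps-per-epoch settings.

```python
class SketchTrainer:
    """Hypothetical stand-in illustrating argument precedence and lazy preparation."""

    def __init__(self, config):
        self.config = config  # plain dict; key names are assumptions
        self.prepared = False
        self.prepare_calls = 0

    def prepare_for_training(self, **kwargs):
        # Stand-in for compiling the model and creating callbacks.
        self.prepare_calls += 1
        self.prepared = True

    def train(self, X, Y, validation_data, epochs=None, steps_per_epoch=None):
        # Explicit arguments win; otherwise fall back to the configuration.
        if epochs is None:
            epochs = self.config["train_epochs"]
        if steps_per_epoch is None:
            steps_per_epoch = self.config["train_steps_per_epoch"]
        if not self.prepared:
            self.prepare_for_training()  # implicit call with default arguments
        # A real implementation would run the optimizer here; the sketch only
        # reports the schedule that would be used.
        return {"epochs": epochs, "steps_per_epoch": steps_per_epoch,
                "updates": epochs * steps_per_epoch}
```

With the documented defaults of 100 epochs and 400 steps per epoch, training performs 100 × 400 = 40,000 parameter updates.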

Supporting functions:

csbdeep.io.load_training_data(file, validation_split=0, axes=None, n_images=None, verbose=False)

Load training data from file in .npz format.

The data file is expected to have the keys:

  • X : Array of training input images.
  • Y : Array of corresponding target images.
  • axes : Axes of the training images.

Parameters:
  • file (str) – File name.
  • validation_split (float) – Fraction of images to use as validation set during training.
  • axes (str, optional) – Must be provided in case the loaded data does not contain axes information.
  • n_images (int, optional) – Can be used to limit the number of images loaded from data.
  • verbose (bool, optional) – Can be used to display information about the loaded images.

Returns two tuples (X_train, Y_train), (X_val, Y_val) of training and validation sets and the axes of the input images. The tuple of validation data will be None if validation_split = 0.

Return type:

tuple( tuple(numpy.ndarray, numpy.ndarray), tuple(numpy.ndarray, numpy.ndarray), str )
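How `validation_split` partitions the data can be sketched as follows. `split_train_val` is hypothetical. Taking the validation images from the end of the arrays without shuffling is an assumption about the library's behaviour. The slicing works on Python lists as well as on NumPy arrays whose first axis indexes images.

```python
def split_train_val(X, Y, validation_split):
    """Split paired samples into training and validation sets.

    Assumption: the last round(n * validation_split) samples form the
    validation set (no shuffling). The validation tuple is None when no
    validation images result, e.g. for validation_split == 0.
    """
    if len(X) != len(Y):
        raise ValueError("X and Y must contain the same number of images")
    if not 0 <= validation_split < 1:
        raise ValueError("validation_split must be in [0, 1)")
    n_val = int(round(len(X) * validation_split))
    if n_val == 0:
        return (X, Y), None
    n_train = len(X) - n_val
    return (X[:n_train], Y[:n_train]), (X[n_train:], Y[n_train:])
```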

csbdeep.internals.nets.common_unet(n_dim=2, n_depth=1, kern_size=3, n_first=16, n_channel_out=1, residual=True, prob_out=False, last_activation='linear')

Construct a common CARE neural net based on U-Net [1] and residual learning [2] to be used for image restoration/enhancement.

  • n_dim (int) – number of image dimensions (2 or 3)
  • n_depth (int) – number of resolution levels of U-Net architecture
  • kern_size (int) – size of convolution filter in all image dimensions
  • n_first (int) – number of convolution filters for first U-Net resolution level (value is doubled after each downsampling operation)
  • n_channel_out (int) – number of channels of the predicted output image
  • residual (bool) – if True, model will internally predict the residual w.r.t. the input (typically better); requires the number of input and output image channels to be equal
  • prob_out (bool) – standard regression (False) or probabilistic prediction (True); if True, model will predict two values for each input pixel (mean and positive scale value)
  • last_activation (str) – name of activation function for the final output layer

Returns:Function to construct the network, which takes the shape of the input image as its argument.
Return type:function

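>>> from csbdeep.internals.nets import common_unet
>>> model = common_unet(2, 1, 3, 16, 1, True, False)((None, None, 1))  # 2D input of any size, 1 channel (channels last)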
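The effect of `residual` and `prob_out` on the final output can be sketched for a single pixel. `output_head` is hypothetical and not part of csbdeep. When `residual=True` it adds the input to the predicted values, and when `prob_out=True` it reads a second group of raw outputs as the scale. Using softplus to keep the scale positive, and applying the residual only to the mean, are both assumptions.

```python
import math


def softplus(r):
    """Numerically stable softplus: log(1 + exp(r)) > 0 for every finite r."""
    return max(r, 0.0) + math.log1p(math.exp(-abs(r)))


def output_head(raw, x, n_channel_out, residual, prob_out):
    """Turn raw per-pixel network outputs into predictions (one pixel).

    raw: n_channel_out values, or 2 * n_channel_out values if prob_out
    x:   input channel values of the same pixel
    """
    mean = list(raw[:n_channel_out])
    if residual:
        if len(x) != n_channel_out:
            raise ValueError("residual learning needs equal input/output channels")
        # the network predicts the residual w.r.t. the input
        mean = [m + xi for m, xi in zip(mean, x)]
    if not prob_out:
        return mean
    # Assumption: softplus maps the second group of outputs to positive scales.
    scale = [softplus(r) for r in raw[n_channel_out:2 * n_channel_out]]
    return mean, scale
```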


[1] Olaf Ronneberger, Philipp Fischer, Thomas Brox. U-Net: Convolutional Networks for Biomedical Image Segmentation. MICCAI 2015.
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition. CVPR 2016.
csbdeep.internals.train.prepare_model(model, optimizer, loss, metrics=('mse', 'mae'), loss_bg_thresh=0, loss_bg_decay=0.06, Y=None)

TODO

csbdeep.utils.tf.export_SavedModel(model, outpath, meta={}, format='zip')

Export Keras model in TensorFlow’s SavedModel format.

See Your Model in Fiji to learn how to use the exported model with our CSBDeep Fiji plugins.

  • model (keras.models.Model) – Keras model to be exported.
  • outpath (str) – Path of the file/folder that the model will be exported to.
  • meta (dict, optional) – Metadata to be saved in an additional meta.json file.
  • format (str, optional) – Can be ‘dir’ to export as a directory or ‘zip’ (default) to export as a ZIP file.

Raises:ValueError – Illegal arguments.
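The `meta` and `format` options can be mimicked with the standard library alone. `package_export` is hypothetical. It copies an already exported model folder, adds `meta.json`, and then either keeps a directory or writes a ZIP file. It does not create a TensorFlow SavedModel and is not the library's export code. Unlike the documented signature, it uses `meta=None` to avoid a mutable default argument.

```python
import json
import shutil
import tempfile
from pathlib import Path


def package_export(model_dir, outpath, meta=None, format="zip"):
    """Hypothetical packaging of an exported model folder with meta.json.

    format='dir' writes a directory at `outpath` (which must not exist yet);
    format='zip' writes a ZIP archive whose name ends with '.zip'.
    """
    if format not in ("dir", "zip"):
        raise ValueError("format must be 'dir' or 'zip'")
    meta = {} if meta is None else meta
    with tempfile.TemporaryDirectory() as tmp:
        staging = Path(tmp) / "export"
        shutil.copytree(model_dir, staging)
        (staging / "meta.json").write_text(json.dumps(meta, indent=2))
        if format == "dir":
            shutil.copytree(staging, outpath)
            return Path(outpath)
        # make_archive appends ".zip" to the base name it is given
        base = str(Path(outpath).with_suffix(""))
        return Path(shutil.make_archive(base, "zip", root_dir=staging))
```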

Other models

Training other CARE models (csbdeep.models.IsotropicCARE, csbdeep.models.UpsamplingCARE, csbdeep.models.ProjectionCARE) currently does not differ from that of a standard model. What changes is the way in which the training data is generated (see Training data generation).