Training data generation

The details of training data generation vary depending on the intended image restoration task (see Model overview). We recommend starting with one of our examples.

We first explain the process for a standard CARE model. To that end, we need to specify matching pairs of raw source and target images, which is done with a RawData object. It is important that you correctly set the axes of the raw images, e.g. CYX for 2D images with a channel dimension before the two lateral dimensions.

Image Axes
X: columns, Y: rows, Z: planes, C: channels, T: frames/time, (S: samples/images)


  • The raw data should be representative of all images that the CARE network will potentially be applied to after training.
  • Usually, it is best not to process the raw images in any other way (e.g., by deconvolution).
  • Source and target images must be well-aligned to obtain effective CARE networks.

class csbdeep.data.RawData(generator, size, description)[source]

Extension of collections.namedtuple() with three fields: generator, size, and description.

  • generator (function) – Function without arguments that returns a generator that yields tuples (x,y,axes,mask), where x is a source image (e.g., with low SNR) and y is the corresponding target image (e.g., with high SNR); mask can either be None or a boolean array that denotes which pixels are eligible to be extracted in create_patches(). Note that x, y, and mask must all be of type numpy.ndarray and are assumed to have the same shape, where the string axes indicates the order and presence of axes of all three arrays.
  • size (int) – Number of tuples that the generator will yield.
  • description (str) – Textual description of the raw data.
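The three-field contract can be sketched with a plain namedtuple stand-in (illustrative only; `make_raw_data` is a hypothetical helper, not part of the library):

```python
from collections import namedtuple
import numpy as np

# Stand-in with the same three fields as csbdeep.data.RawData
# (illustrative only; use the real class in practice).
RawDataLike = namedtuple('RawDataLike', ('generator', 'size', 'description'))

def make_raw_data(sources, targets, axes='YX'):
    """Hypothetical helper: wrap matching source/target arrays in the contract."""
    assert len(sources) == len(targets)
    def _gen():
        for x, y in zip(sources, targets):
            yield x, y, axes, None   # mask=None: all pixels are eligible
    return RawDataLike(generator=_gen, size=len(sources),
                       description='in-memory pairs')

sources = [np.random.rand(64, 64) for _ in range(3)]
targets = [np.random.rand(64, 64) for _ in range(3)]
data = make_raw_data(sources, targets)
for x, y, axes, mask in data.generator():
    assert x.shape == y.shape and mask is None
```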
static from_arrays(X, Y, axes='CZYX')[source]

Get pairs of corresponding images from numpy arrays.

static from_folder(basepath, source_dirs, target_dir, axes='CZYX', pattern='*.tif*')[source]

Get pairs of corresponding TIFF images read from folders.

Two images correspond to each other if they have the same file name, but are located in different folders.

  • basepath (str) – Base folder that contains sub-folders with images.
  • source_dirs (list or tuple) – List of folder names relative to basepath that contain the source images (e.g., with low SNR).
  • target_dir (str) – Folder name relative to basepath that contains the target images (e.g., with high SNR).
  • axes (str) – Semantics of axes of loaded images (assumed to be the same for all images).
  • pattern (str) – Glob-style pattern to match the desired TIFF images.

Returns a RawData object, whose generator is used to yield all matching TIFF pairs. The generator will return a tuple (x,y,axes,mask), where x is an image from one of the source_dirs and y is the corresponding image from target_dir; mask is set to None.

Return type:

RawData

FileNotFoundError – If an image found in a source_dir does not exist in target_dir.


>>> !tree data
├── GT
│   ├── imageA.tif
│   ├── imageB.tif
│   └── imageC.tif
├── source1
│   ├── imageA.tif
│   └── imageB.tif
└── source2
    ├── imageA.tif
    └── imageC.tif
>>> data = RawData.from_folder(basepath='data', source_dirs=['source1','source2'], target_dir='GT', axes='YX')
>>> n_images = data.size
>>> for source_x, target_y, axes, mask in data.generator():
...     pass
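The file-matching rule above can be sketched as follows (a hypothetical helper mirroring the documented behaviour, not the library's actual code):

```python
from pathlib import Path

def iter_image_pairs(basepath, source_dirs, target_dir, pattern='*.tif*'):
    """Hypothetical sketch of the pairing rule behind RawData.from_folder:
    two images correspond if they share a file name across folders."""
    base = Path(basepath)
    for source_dir in source_dirs:
        for src in sorted((base / source_dir).glob(pattern)):
            tgt = base / target_dir / src.name
            if not tgt.exists():
                # mirrors the documented FileNotFoundError behaviour
                raise FileNotFoundError(str(tgt))
            yield src, tgt
```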

With the raw data specified as above, the function create_patches() can be used to randomly extract patches of a given size that are suitable for training. By default, patches are normalized based on a range of percentiles computed on the raw images, which in our experience tends to lead to more robust CARE networks. Unless specified otherwise, patches that consist of pure background are also excluded from extraction, since they do not contain interesting structures.

If the target images have fewer axes than the source images, please use the function create_patches_reduced_target() instead. Examples: (a) the target image is a projection of the source image along an axis (see csbdeep.models.ProjectionCARE and example); (b) the source image has multiple channels, but the target image has no separate channels.

create_patches(raw_data, patch_size, n_patches_per_image, patch_axes=None, save_file=None, transforms=None, patch_filter=<function no_background_patches.<locals>._filter>, normalization=<function norm_percentiles.<locals>._normalize>, shuffle=True, verbose=True)[source]

Create normalized training data to be used for neural network training.

  • raw_data (RawData) – Object that yields matching pairs of raw images.
  • patch_size (tuple) – Shape of the patches to be extracted from raw images. Must be compatible with the number of dimensions and axes of the raw images. As a general rule, use a power of two along all XYZT axes, or at least a size divisible by 8.
  • n_patches_per_image (int) – Number of patches to be sampled/extracted from each raw image pair (after transformations, see below).
  • patch_axes (str or None) – Axes of the extracted patches. If None, will assume to be equal to that of transformed raw data.
  • save_file (str or None) – File name to save training data to disk in .npz format (see save_training_data()). If None, data will not be saved.
  • transforms (list or tuple, optional) – List of Transform objects that apply additional transformations to the raw images. This can be used to augment the set of raw images (e.g., by including rotations). Set to None to disable. Default: None.
  • patch_filter (function, optional) – Function to determine for each image pair which patches are eligible to be extracted (default: no_background_patches()). Set to None to disable.
  • normalization (function, optional) – Function that takes arguments (patches_x, patches_y, x, y, mask, channel), whose purpose is to normalize the patches (patches_x, patches_y) extracted from the associated raw images (x, y, with mask; see RawData). Default: norm_percentiles().
  • shuffle (bool, optional) – Randomly shuffle all extracted patches.
  • verbose (bool, optional) – Display overview of images, transforms, etc.

Returns a tuple (X, Y, axes) with the normalized extracted patches from all (transformed) raw images and their axes. X is the array of patches extracted from source images with Y being the array of corresponding target patches. The shape of X and Y is as follows: (n_total_patches, n_channels, …). For single-channel images, n_channels will be 1.

Return type:

tuple(numpy.ndarray, numpy.ndarray, str)


ValueError – Various reasons.
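The core sampling step can be sketched as follows (a hypothetical, simplified stand-in for part of create_patches: no patch filter, no normalization):

```python
import numpy as np

def sample_patches(x, y, patch_size, n_patches, seed=0):
    """Extract n_patches random, aligned patch pairs from x and y.
    Simplified sketch only; the real create_patches also filters and
    normalizes patches."""
    rng = np.random.default_rng(seed)
    assert x.shape == y.shape
    assert all(p <= s for p, s in zip(patch_size, x.shape))
    X = np.empty((n_patches,) + tuple(patch_size), dtype=x.dtype)
    Y = np.empty_like(X)
    for i in range(n_patches):
        # random top-left corner such that the patch fits inside the image
        corner = [rng.integers(0, s - p + 1) for s, p in zip(x.shape, patch_size)]
        sl = tuple(slice(c, c + p) for c, p in zip(corner, patch_size))
        X[i], Y[i] = x[sl], y[sl]
    return X, Y

x = np.random.rand(128, 128)
y = np.random.rand(128, 128)
X, Y = sample_patches(x, y, (32, 32), n_patches=16)
```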


>>> raw_data = RawData.from_folder(basepath='data', source_dirs=['source1','source2'], target_dir='GT', axes='ZYX')
>>> X, Y, XY_axes = create_patches(raw_data, patch_size=(32,128,128), n_patches_per_image=16)

create_patches_reduced_target(raw_data, patch_size, n_patches_per_image, reduction_axes, target_axes=None, **kwargs)[source]

Create normalized training data to be used for neural network training.

In contrast to create_patches(), it is assumed that the target image has reduced dimensionality (i.e. size 1) along one or several axes (reduction_axes).

  • raw_data (RawData) – See create_patches().
  • patch_size (tuple) – See create_patches().
  • n_patches_per_image (int) – See create_patches().
  • reduction_axes (str) – Axes where the target images have a reduced dimension (i.e. size 1) compared to the source images.
  • target_axes (str) – Axes of the raw target images. If None, will be assumed to be equal to that of the raw source images.
  • kwargs (dict) – Additional parameters as in create_patches().

See create_patches(). Note that the shape of the target data will be 1 along all reduction axes.

Return type:

tuple(numpy.ndarray, numpy.ndarray, str)
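For instance, a target obtained by projecting the source along Z keeps a size-1 Z axis, which is the layout expected for reduction_axes='Z' (illustrative shapes only):

```python
import numpy as np

# Source stack with axes ZYX; the target is its max projection along Z,
# keeping a size-1 Z axis (reduction_axes='Z').
source = np.random.rand(16, 64, 64)          # ZYX
target = source.max(axis=0, keepdims=True)   # shape (1, 64, 64)
```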

Supporting functions

no_background_patches(threshold=0.4, percentile=99.9)[source]

Returns a patch filter to be used by create_patches() to determine for each image pair which patches are eligible for sampling. The purpose is to only sample patches from “interesting” regions of the raw image that actually contain a substantial amount of non-background signal. To that end, a maximum filter is applied to the target image to find the largest values in a region.

  • threshold (float, optional) – Scalar threshold between 0 and 1 that will be multiplied with the (outlier-robust) maximum of the image (see percentile below) to denote a lower bound. Only patches with a maximum value above this lower bound are eligible to be sampled.
  • percentile (float, optional) – Percentile value to denote the (outlier-robust) maximum of an image, i.e. should be close to 100.
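A rough, hypothetical stand-in for such a filter might look like the following (note that the real implementation additionally applies a maximum filter of the patch size before thresholding, which this sketch omits):

```python
import numpy as np

def simple_background_filter(threshold=0.4, percentile=99.9):
    """Simplified stand-in: marks pixels whose value exceeds
    threshold * (outlier-robust image maximum)."""
    def _filter(y, x, patch_size):
        bound = threshold * np.percentile(y, percentile)
        mask = y > bound
        if not mask.any():
            raise ValueError('no eligible locations for patch sampling')
        return mask
    return _filter

y = np.zeros((64, 64))
y[20:40, 20:40] = 1.0                          # signal on dark background
mask = simple_background_filter()(y, y, (16, 16))
```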

Function that takes an image pair (y,x) and the patch size as arguments and returns a binary mask of the same size as the image (to denote the locations eligible for sampling for create_patches()). At least one pixel of the binary mask must be True, otherwise there are no patches to sample.

Return type:

function

ValueError – Illegal arguments.

norm_percentiles(percentiles=<function sample_percentiles.<locals>.<lambda>>, relu_last=False)[source]

Normalize extracted patches based on percentiles from corresponding raw image.

  • percentiles (tuple, optional) – A tuple (pmin, pmax) or a function that returns such a tuple, where the extracted patches are (affinely) normalized such that a value of 0 (1) corresponds to the pmin-th (pmax-th) percentile of the raw image (default: sample_percentiles()).
  • relu_last (bool, optional) – Flag to indicate whether the last activation of the CARE network is/will be using a ReLU activation function (default: False)
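The affine mapping it performs can be sketched as follows (`percentile_norm` is a hypothetical helper; the pmin-th and pmax-th percentiles of the raw image map to 0 and 1):

```python
import numpy as np

def percentile_norm(patch, img, pmin=1.0, pmax=99.8):
    """Hypothetical helper: affinely map the pmin-th (pmax-th)
    percentile of the raw image to 0 (1) in the patch."""
    lo, hi = np.percentile(img, pmin), np.percentile(img, pmax)
    return (patch - lo) / (hi - lo + 1e-20)   # small epsilon avoids div by 0

img = np.random.rand(128, 128) * 1000.0       # arbitrary intensity range
norm = percentile_norm(img[:32, :32], img)
```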

Function that does percentile-based normalization to be used in create_patches().

Return type:

function

ValueError – Illegal arguments.

sample_percentiles(pmin=(1, 3), pmax=(99.5, 99.9))[source]

Sample percentile values from a uniform distribution.

  • pmin (tuple) – Tuple of two values that denotes the interval for sampling low percentiles.
  • pmax (tuple) – Tuple of two values that denotes the interval for sampling high percentiles.
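A minimal sketch of this sampling, assuming uniform draws from the two intervals (`make_percentile_sampler` is a hypothetical name):

```python
import random

def make_percentile_sampler(pmin=(1, 3), pmax=(99.5, 99.9)):
    """Each call of the returned function draws a fresh (pl, ph) pair,
    uniformly from the pmin and pmax intervals respectively."""
    def _sample():
        return random.uniform(*pmin), random.uniform(*pmax)
    return _sample

pl, ph = make_percentile_sampler()()
```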

Function without arguments that returns (pl,ph), where pl (ph) is a sampled low (high) percentile.

Return type:

function

ValueError – Illegal arguments.

save_training_data(file, X, Y, axes)[source]

Save training data in .npz format.

  • file (str) – File name
  • X (numpy.ndarray) – Array of patches extracted from source images.
  • Y (numpy.ndarray) – Array of corresponding target patches.
  • axes (str) – Axes of the extracted patches.
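A plausible sketch of such an .npz layout via numpy.savez (the exact keys csbdeep writes may differ):

```python
import os, tempfile
import numpy as np

def save_npz(file, X, Y, axes):
    # Store patch arrays plus the axes string in one .npz archive.
    np.savez(file, X=X, Y=Y, axes=axes)

X = np.zeros((8, 1, 32, 32), dtype=np.float32)
Y = np.ones_like(X)
path = os.path.join(tempfile.mkdtemp(), 'my_training_data.npz')
save_npz(path, X, Y, 'SCYX')
with np.load(path) as f:
    restored_axes = str(f['axes'])
```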

Anisotropic distortions

We provide the function anisotropic_distortions() that returns a Transform object (see Transforms) to be used for creating training data for csbdeep.models.UpsamplingCARE and csbdeep.models.IsotropicCARE.

anisotropic_distortions(subsample, psf, psf_axes=None, poisson_noise=False, gauss_sigma=0, subsample_axis='X', yield_target='source', crop_threshold=0.2)[source]

Simulate anisotropic distortions.

Modify the first image (obtained from input generator) along one axis to mimic the distortions that typically occur due to low resolution along the Z axis. Note that the modified image is finally upscaled to obtain the same resolution as the unmodified input image and is yielded as the ‘source’ image (see RawData). The mask from the input generator is simply passed through.

The following operations are applied to the image (in order):

  1. Convolution with PSF
  2. Poisson noise
  3. Gaussian noise
  4. Subsampling along subsample_axis
  5. Upsampling along subsample_axis (to former size).
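Steps 4 and 5 can be sketched as follows (nearest-neighbour upsampling is used here as a simplification; `subsample_then_upsample` is a hypothetical helper):

```python
import numpy as np

def subsample_then_upsample(img, factor, axis=0):
    """Drop all but every factor-th sample along one axis, then
    nearest-neighbour upsample back to the former size.
    Assumes factor evenly divides the axis length."""
    size = img.shape[axis]
    assert size % factor == 0
    low = np.take(img, np.arange(0, size, factor), axis=axis)
    return np.repeat(low, factor, axis=axis)

img = np.random.rand(64, 64)
distorted = subsample_then_upsample(img, factor=4)
```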
  • subsample (float) – Subsampling factor to mimic distortions along Z.
  • psf (numpy.ndarray or None) – Point spread function (PSF) that is supposed to mimic blurring of the microscope due to reduced axial resolution. Set to None to disable.
  • psf_axes (str or None) – Axes of the PSF. If None, psf axes are assumed to be the same as of the image that it is applied to.
  • poisson_noise (bool) – Flag to indicate whether Poisson noise should be applied to the image.
  • gauss_sigma (float) – Standard deviation of white Gaussian noise to be added to the image.
  • subsample_axis (str) – Subsampling image axis (default X).
  • yield_target (str) – Which image from the input generator should be yielded by the generator (‘source’ or ‘target’). If ‘source’, the unmodified input/source image (from which the distorted image is computed) is yielded as the target image. If ‘target’, the target image from the input generator is simply passed through.
  • crop_threshold (float) – The subsample factor must evenly divide the image size along the subsampling axis to prevent potential image misalignment. If this is not the case the subsample factors are modified and the raw image may be cropped along the subsampling axis up to a fraction indicated by crop_threshold.

Returns a Transform object intended to be used with create_patches().

Return type:

Transform

ValueError – Various reasons.


A Transform can be used to modify and augment the set of raw images before they are used in create_patches() to generate training data.


class csbdeep.data.Transform(name, generator, size)[source]

Extension of collections.namedtuple() with three fields: name, generator, and size.

  • name (str) – Name of the applied transformation.
  • generator (function) – Function that takes a generator as input and itself returns a generator; input and returned generator have the same structure as that of RawData. The purpose of the returned generator is to augment the images provided by the input generator through additional transformations. It is important that the returned generator also includes every input tuple unchanged.
  • size (int) – Number of transformations applied to every image (obtained from the input generator).
static identity()[source]
Returns: Identity transformation that passes every input through unchanged.
Return type: Transform

Data augmentation

Instead of recording raw images where structures of interest appear in all possible appearance variations, it can be easier to augment the raw dataset by including some of those variations that can be easily synthesized. Typical examples are axis-aligned rotations if structures of interest can appear at arbitrary rotations. We currently haven’t implemented any such transformations, but plan to at least add axis-aligned rotations and flips later.
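As an illustration of what such an augmentation could look like under the Transform contract described above, here is a hypothetical flip transform (a plain namedtuple stand-in; not part of the library):

```python
from collections import namedtuple
import numpy as np

TransformLike = namedtuple('TransformLike', ('name', 'generator', 'size'))

def flip_transform(axis=-1):
    """Hypothetical flip augmentation obeying the Transform contract:
    every input tuple is passed through unchanged, plus one flipped copy,
    so size=2 transformations are applied per input image."""
    def _wrap(input_gen):
        for x, y, axes, mask in input_gen:
            yield x, y, axes, mask  # the unchanged input tuple
            yield (np.flip(x, axis), np.flip(y, axis), axes,
                   None if mask is None else np.flip(mask, axis))
    return TransformLike(name='flip', generator=_wrap, size=2)
```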

Other transforms

permute_axes(axes)[source]

Transformation to permute images axes.

Parameters: axes (str) – Target axes, to which the input images will be permuted.
Returns: A Transform object whose generator will perform the axes permutation of x, y, and mask.
Return type: Transform

crop_images(slices)[source]

Transformation to crop all images (and mask).

Note that slices must be compatible with the image size.

Parameters: slices (list or tuple of slice) – List of slices to apply to each dimension of the image.
Returns: A Transform object whose generator will perform image cropping of x, y, and mask.
Return type: Transform
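Both of these transformations boil down to standard NumPy operations; a minimal sketch (illustrative shapes only):

```python
import numpy as np

x = np.random.rand(3, 16, 128, 128)                  # CZYX
# Axes permutation (as in permute_axes): reorder dimensions via transpose.
permuted = np.transpose(x, (1, 2, 3, 0))             # CZYX -> ZYXC
# Cropping (as in crop_images): one slice per image dimension,
# each of which must lie within the image bounds.
slices = (slice(None), slice(None), slice(0, 64), slice(32, 96))
cropped = x[slices]
```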