Source code for csbdeep.io

# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals, absolute_import, division
from six.moves import range, zip, map, reduce, filter
from six import string_types

import numpy as np
try:
    from tifffile import imwrite as imsave
except ImportError:
    from tifffile import imsave
import warnings

from ..utils import _raise, axes_check_and_normalize, axes_dict, move_image_axes, move_channel_for_backend, backend_channels_last
from ..utils.six import Path



def save_tiff_imagej_compatible(file, img, axes, **imsave_kwargs):
    """Save image in ImageJ-compatible TIFF format.

    Parameters
    ----------
    file : str
        File name
    img : numpy.ndarray
        Image
    axes: str
        Axes of ``img``
    imsave_kwargs : dict, optional
        Keyword arguments for :func:`tifffile.imsave`

    """
    axes = axes_check_and_normalize(axes,img.ndim,disallowed='S')

    # convert to imagej-compatible data type
    t = img.dtype
    if   'float' in t.name: t_new = np.float32
    elif 'uint'  in t.name: t_new = np.uint16 if t.itemsize >= 2 else np.uint8
    elif 'int'   in t.name: t_new = np.int16
    else:                   t_new = t
    img = img.astype(t_new, copy=False)
    if t != t_new:
        warnings.warn("Converting data type from '%s' to ImageJ-compatible '%s'." % (t, np.dtype(t_new)))

    # move axes to correct positions for imagej
    img = move_image_axes(img, axes, 'TZCYX', True)

    imsave_kwargs['imagej'] = True
    imsave(file, img, **imsave_kwargs)



[docs]def load_training_data(file, validation_split=0, axes=None, n_images=None, verbose=False): """Load training data from file in ``.npz`` format. The data file is expected to have the keys: - ``X`` : Array of training input images. - ``Y`` : Array of corresponding target images. - ``axes`` : Axes of the training images. Parameters ---------- file : str File name validation_split : float Fraction of images to use as validation set during training. axes: str, optional Must be provided in case the loaded data does not contain ``axes`` information. n_images : int, optional Can be used to limit the number of images loaded from data. verbose : bool, optional Can be used to display information about the loaded images. Returns ------- tuple( tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`), str ) Returns two tuples (`X_train`, `Y_train`), (`X_val`, `Y_val`) of training and validation sets and the axes of the input images. The tuple of validation data will be ``None`` if ``validation_split = 0``. """ f = np.load(file) X, Y = f['X'], f['Y'] if axes is None: axes = f['axes'] axes = axes_check_and_normalize(axes) # assert X.shape == Y.shape assert X.ndim == Y.ndim assert len(axes) == X.ndim assert 'C' in axes if n_images is None: n_images = X.shape[0] assert X.shape[0] == Y.shape[0] assert 0 < n_images <= X.shape[0] assert 0 <= validation_split < 1 X, Y = X[:n_images], Y[:n_images] channel = axes_dict(axes)['C'] if validation_split > 0: n_val = int(round(n_images * validation_split)) n_train = n_images - n_val assert 0 < n_val and 0 < n_train X_t, Y_t = X[-n_val:], Y[-n_val:] X, Y = X[:n_train], Y[:n_train] assert X.shape[0] == n_train and X_t.shape[0] == n_val X_t = move_channel_for_backend(X_t,channel=channel) Y_t = move_channel_for_backend(Y_t,channel=channel) X = move_channel_for_backend(X,channel=channel) Y = move_channel_for_backend(Y,channel=channel) axes = axes.replace('C','') # remove channel if backend_channels_last(): axes = axes+'C' else: axes = axes[:1]+'C'+axes[1:] data_val = (X_t,Y_t) if validation_split > 0 else None if verbose: ax = axes_dict(axes) n_train, n_val = len(X), len(X_t) if validation_split>0 else 0 image_size = tuple( X.shape[ax[a]] for a in axes if a in 'TZYX' ) n_dim = len(image_size) n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']] print('number of training images:\t', n_train) print('number of validation images:\t', n_val) print('image size (%dD):\t\t'%n_dim, image_size) print('axes:\t\t\t\t', axes) print('channels in / out:\t\t', n_channel_in, '/', n_channel_out) return (X,Y), data_val, axes
[docs]def save_training_data(file, X, Y, axes): """Save training data in ``.npz`` format. Parameters ---------- file : str File name X : :class:`numpy.ndarray` Array of patches extracted from source images. Y : :class:`numpy.ndarray` Array of corresponding target patches. axes : str Axes of the extracted patches. """ isinstance(file,(Path,string_types)) or _raise(ValueError()) file = Path(file).with_suffix('.npz') file.parent.mkdir(parents=True,exist_ok=True) axes = axes_check_and_normalize(axes) len(axes) == X.ndim or _raise(ValueError()) np.savez(str(file), X=X, Y=Y, axes=axes)