This notebook demonstrates applying a CARE model for a 2D denoising task, assuming that training was already completed via 2_training.ipynb.
The trained model is assumed to be located in the folder 'models' with the name 'my_model'.
More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import CARE
Load and plot the test image (with associated ground truth) and define its image axes, which will be needed later for CARE prediction.
y = imread('data/test/GT/img_0010.tif')
x = imread('data/test/low/img_0010.tif')
axes = 'YX'
print('image size =', x.shape)
print('image axes =', axes)
plt.figure(figsize=(13,5))
plt.subplot(1,2,1)
plt.imshow(x, cmap="magma")
plt.colorbar()
plt.title("low")
plt.subplot(1,2,2)
plt.imshow(y, cmap="magma")
plt.colorbar()
plt.title("high");
image size = (256, 256)
image axes = YX
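The axes string names each dimension of x in order, so 'YX' means rows (Y) then columns (X); a 3D stack would instead use something like 'ZYX'. The following is a minimal sketch of a hypothetical helper, axes_consistent (not part of CSBDeep), that checks an axes string against an array before prediction, assuming the axis letters S, T, C, Z, Y, X used by CSBDeep:

```python
import numpy as np

# Hypothetical helper (not part of CSBDeep): sanity-check an axes string
# against an image array before calling model.predict(x, axes).
ALLOWED_AXES = set("STCZYX")  # assumption: CSBDeep-style axis letters

def axes_consistent(x, axes):
    """Return True if `axes` names each dimension of `x` exactly once."""
    axes = axes.upper()
    return (
        len(axes) == x.ndim                # one letter per array dimension
        and len(set(axes)) == len(axes)    # no repeated letters
        and set(axes) <= ALLOWED_AXES      # only known axis letters
        and "X" in axes and "Y" in axes    # both spatial image axes present
    )

x = np.zeros((256, 256), dtype=np.float32)  # same shape as the test image
print(axes_consistent(x, "YX"))             # True
print(axes_consistent(x, "ZYX"))            # False: 3 letters for a 2D array
stack = np.zeros((10, 256, 256), dtype=np.float32)
print(axes_consistent(stack, "ZYX"))        # True
```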
Load the trained model (located in base directory 'models' with name 'my_model') from disk.
The configuration was saved during training and is automatically loaded when CARE is initialized with config=None.
model = CARE(config=None, name='my_model', basedir='models')
Loading network weights from 'weights_best.h5'.
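Loading with config=None works because the configuration is stored inside the model folder during training. The sketch below illustrates that save-then-reload-by-name pattern with a hypothetical ToyModel class; it is not the CSBDeep implementation, and the file name config.json and the config values are assumptions chosen for illustration:

```python
import json
import tempfile
from pathlib import Path

# Hypothetical toy class (NOT the CSBDeep API) illustrating the pattern:
# the config is written to basedir/name at training time and read back
# from there when the model is created with config=None.
class ToyModel:
    CONFIG_FILE = "config.json"  # assumed file name, for illustration only

    def __init__(self, config, name, basedir):
        self.logdir = Path(basedir) / name
        config_path = self.logdir / self.CONFIG_FILE
        if config is None:
            # Inference: reuse the configuration saved during training.
            with open(config_path) as f:
                self.config = json.load(f)
        else:
            # Training: persist the configuration for later reuse.
            self.logdir.mkdir(parents=True, exist_ok=True)
            with open(config_path, "w") as f:
                json.dump(config, f)
            self.config = config

with tempfile.TemporaryDirectory() as basedir:
    ToyModel({"axes": "YX", "n_channel_in": 1}, name="my_model", basedir=basedir)
    reloaded = ToyModel(None, name="my_model", basedir=basedir)
    print(reloaded.config)  # {'axes': 'YX', 'n_channel_in': 1}
```

Keeping the configuration next to the weights means the prediction notebook only needs the model name and base directory, not a copy of the training settings.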
Predict the restored image (the image will be successively split into smaller tiles if there are memory issues).
%%time
restored = model.predict(x, axes)
CPU times: user 1.14 s, sys: 420 ms, total: 1.56 s
Wall time: 1.35 s
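If no tiling is specified, a too-large input triggers an out-of-memory error and prediction is retried with more tiles. The following is a hypothetical sketch of such a retry loop (predict_with_retries and toy_predict are illustrative names, not CSBDeep functions, and a Python MemoryError stands in for the deep learning framework's out-of-memory error). It also shows where the time overhead comes from: every failed attempt costs one extra call.

```python
import numpy as np

# Hypothetical retry loop; this is NOT CSBDeep's implementation, which
# differs in its details (error type, choice of axis to split, etc.).
def predict_with_retries(predict_fn, x, n_tiles=(1, 1), max_tiles=64):
    n_tiles = tuple(n_tiles)
    while True:
        try:
            return predict_fn(x, n_tiles), n_tiles
        except MemoryError:
            # Double the number of tiles along the least-tiled axis.
            i = int(np.argmin(n_tiles))
            n_tiles = tuple(2 * t if j == i else t for j, t in enumerate(n_tiles))
            if np.prod(n_tiles) > max_tiles:
                raise  # give up: re-raise the last MemoryError

# Toy predictor that "runs out of memory" if a tile exceeds 64*64 pixels.
def toy_predict(x, n_tiles):
    if x.size / np.prod(n_tiles) > 64 * 64:
        raise MemoryError("tile too large")
    return x  # identity stands in for the network

x = np.zeros((256, 256), dtype=np.float32)
_, used = predict_with_retries(toy_predict, x)
print(used)  # (4, 4)
```

Here four attempts, (1, 1), (2, 1), (2, 2) and (4, 2), fail before the (4, 4) tiling succeeds.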
Alternatively, one can directly set n_tiles to an appropriate value to avoid the time overhead of multiple retries in case of memory issues.
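Tiling only reproduces the full-image result if neighbouring tiles overlap: each output pixel of a convolutional network depends on a neighbourhood of input pixels (its receptive field), so every tile must be extended by at least that radius before prediction. The sketch below demonstrates this with a 3x3 mean filter standing in for the network; predict_tiled and mean3x3 are hypothetical helpers, not CSBDeep's implementation. Assuming a recent CSBDeep version, n_tiles takes one entry per image axis, e.g. model.predict(x, axes, n_tiles=(2,2)) for axes 'YX'; check the documentation of your installed version.

```python
import numpy as np

# Hypothetical tiled prediction with overlap (not CSBDeep's implementation):
# each tile is padded by a halo of `overlap` pixels, processed on its own,
# and only its central part is written back into the output.
def predict_tiled(fn, x, n_tiles, overlap):
    out = np.empty_like(x)
    rows = np.array_split(np.arange(x.shape[0]), n_tiles[0])
    cols = np.array_split(np.arange(x.shape[1]), n_tiles[1])
    for r in rows:
        for c in cols:
            r0, r1 = r[0], r[-1] + 1
            c0, c1 = c[0], c[-1] + 1
            # Extend the tile by the halo, clipped at the image border.
            R0, R1 = max(r0 - overlap, 0), min(r1 + overlap, x.shape[0])
            C0, C1 = max(c0 - overlap, 0), min(c1 + overlap, x.shape[1])
            pred = fn(x[R0:R1, C0:C1])
            # Keep only the central (non-halo) part of the prediction.
            out[r0:r1, c0:c1] = pred[r0 - R0:r1 - R0, c0 - C0:c1 - C0]
    return out

# Toy "network": 3x3 mean filter with edge padding (receptive field radius 1).
def mean3x3(img):
    p = np.pad(img, 1, mode="edge")
    h, w = img.shape
    return sum(p[i:i + h, j:j + w] for i in range(3) for j in range(3)) / 9.0

rng = np.random.default_rng(0)
x = rng.random((256, 256))
full = mean3x3(x)
print(np.allclose(predict_tiled(mean3x3, x, (2, 2), overlap=1), full))  # True
print(np.allclose(predict_tiled(mean3x3, x, (2, 2), overlap=0), full))  # False
```

With overlap equal to the filter radius the stitched result matches the full-image prediction; without overlap, pixels next to internal tile borders differ.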
Show the test image pair and the predicted restored image (middle).
The second plot shows the normalized intensity profile along one image row (row 180) for all three images.
from csbdeep.utils import normalize
plt.figure(figsize=(15,10))
plot_some(np.stack([x,restored,y]),
title_list=[['low','CARE','GT']],
pmin=2,pmax=99.8);
plt.figure(figsize=(10,5))
# Intensity profile along row 180, each image normalized to its 1st-99.7th percentile range
for _x, _name in zip((x, restored, y), ('low', 'CARE', 'GT')):
    plt.plot(normalize(_x, 1, 99.7)[180], label=_name, lw=2)
plt.legend();
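The three images have different intensity ranges, so each profile is rescaled with normalize before plotting. Below is a minimal sketch of percentile normalization, assuming it behaves like csbdeep.utils.normalize(x, pmin, pmax) without clipping; percentile_normalize is a hypothetical stand-in and the image is synthetic:

```python
import numpy as np

# Sketch of percentile normalization (assumed to mirror csbdeep.utils.normalize
# without clipping): the pmin-th percentile maps to 0, the pmax-th to 1.
def percentile_normalize(x, pmin=1, pmax=99.7, eps=1e-20):
    lo, hi = np.percentile(x, [pmin, pmax])
    return (x - lo) / (hi - lo + eps)

rng = np.random.default_rng(0)
img = rng.gamma(2.0, 50.0, size=(256, 256))  # synthetic stand-in image
img_scaled = 3.0 * img + 10.0                # same image, other gain/offset

# Row 180 profiles of both versions after normalization.
p1 = percentile_normalize(img)[180]
p2 = percentile_normalize(img_scaled)[180]
print(np.allclose(p1, p2))  # True
```

Because percentiles follow any positive gain and offset, the normalized profiles of an image and a linearly rescaled copy coincide, which is what makes the low, CARE and GT profiles comparable on one axis.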