This notebook demonstrates applying a probabilistic CARE model for a 2D denoising task, assuming that training was already completed via 1_training.ipynb.
The trained model is assumed to be located in the folder models with the name my_model.
More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import load_training_data, save_tiff_imagej_compatible
from csbdeep.models import CARE
The example data should have been downloaded in 1_training.ipynb.
Just in case, we will download it here again if it's not already present.
download_and_extract_zip_file(
url = 'http://csbdeep.bioimagecomputing.com/example_data/synthetic_disks.zip',
targetdir = 'data',
)
Files found, nothing to download. data: - synthetic_disks - synthetic_disks/data.npz
Load the validation images used during model training.
X_val, Y_val = load_training_data('data/synthetic_disks/data.npz', validation_split=0.1, verbose=True)[1]
number of training images:   180
number of validation images: 20
image size (2D):             (128, 128)
axes:                        SYXC
channels in / out:           1 / 1
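The trailing [1] selects the validation split from the tuple returned by load_training_data, which contains (training data, validation data, axes). A minimal numpy sketch of such a fractional hold-out split, using hypothetical stand-in arrays rather than the real data (csbdeep's exact split policy may differ in detail):

```python
import numpy as np

# Hypothetical stand-in array with the same shape as the example data
# (200 images of 128x128 with a single channel) -- not the real data.
images = np.zeros((200, 128, 128, 1), dtype=np.float32)

validation_split = 0.1
n_val = int(round(validation_split * len(images)))  # 10% of 200 -> 20

# Hold out the last n_val images for validation.
images_train, images_val = images[:-n_val], images[-n_val:]

print(len(images_train), len(images_val))  # 180 20
```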
We will apply the trained CARE model here to restore one validation image x (with associated ground truth y).
y = Y_val[2,...,0]
x = X_val[2,...,0]
axes = 'YX'
print('image size =', x.shape)
print('image axes =', axes)
plt.figure(figsize=(16,10))
plot_some(np.stack([x,y]), title_list=[['input','target (GT)']]);
image size = (128, 128) image axes = YX
Load the trained model (located in base directory models with name my_model) from disk. The configuration was saved during training and is automatically loaded when CARE is initialized with config=None.
model = CARE(config=None, name='my_model', basedir='models')
Loading network weights from 'weights_best.h5'.
If you're only interested in a restored image, you can predict it here as in the non-probabilistic case. For a probabilistic model, this returns the expected restored image, i.e. the per-pixel mean of the predicted output distributions.
Note 1: Since the synthetic image is already normalized, we don't need to do additional normalization.
Note 2: Out-of-memory problems during model.predict often indicate that the GPU is used by another process. In particular, shut down the training notebook before running the prediction (you may need to restart this notebook).
restored = model.predict(x, axes, normalizer=None)
plt.figure(figsize=(16,10))
plot_some(np.stack([x,restored]), title_list=[['input','expected restored image']]);
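If the input were not already normalized, you would pass a normalizer object instead of None; csbdeep provides PercentileNormalizer in csbdeep.data for this purpose. The self-contained numpy sketch below mimics the usual percentile normalization (mapping the 2nd and 99.8th percentiles to roughly 0 and 1); the exact behavior of csbdeep's class may differ in detail:

```python
import numpy as np

def percentile_normalize(img, pmin=2.0, pmax=99.8, eps=1e-20):
    """Map the pmin/pmax percentiles of img to 0 and 1.

    Values outside that percentile range will fall outside [0, 1].
    This is a sketch of percentile normalization, not csbdeep's implementation.
    """
    lo, hi = np.percentile(img, pmin), np.percentile(img, pmax)
    return (img - lo) / (hi - lo + eps)

rng = np.random.default_rng(0)
img = rng.normal(100.0, 10.0, size=(128, 128)).astype(np.float32)
norm = percentile_normalize(img)

print(np.percentile(norm, 2))     # ~0
print(np.percentile(norm, 99.8))  # ~1
```

With csbdeep itself, this corresponds to something like model.predict(x, axes, normalizer=PercentileNormalizer(2, 99.8)).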
Save the restored image as an ImageJ-compatible TIFF file, i.e. the image can be opened in ImageJ/Fiji with correct axes semantics.
Path('results').mkdir(exist_ok=True)
save_tiff_imagej_compatible('results/%s_validation_image.tif' % model.name, restored, axes)
We now predict the per-pixel Laplace distributions and return an object to work with these.
restored_prob = model.predict_probabilistic(x, axes, normalizer=None)
Plot the mean and scale parameters of the per-pixel Laplace distributions.
plt.figure(figsize=(16,10))
plot_some(np.stack([restored_prob.mean(),restored_prob.scale()]), title_list=[['mean','scale']]);
Plot the variance and entropy of the per-pixel Laplace distributions.
plt.figure(figsize=(16,10))
plot_some(np.stack([restored_prob.var(),restored_prob.entropy()]), title_list=[['variance','entropy']]);
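For reference, a Laplace distribution with location μ and scale b has variance 2b² and differential entropy 1 + ln(2b), so the variance and entropy images above follow directly from the scale image. A quick numpy check of the variance formula against a sample estimate, using illustrative values rather than actual model outputs:

```python
import numpy as np

mu, b = 0.5, 0.1  # illustrative location and scale, not actual model outputs

var_closed = 2 * b**2            # Laplace variance: 2 b^2
ent_closed = 1 + np.log(2 * b)   # Laplace differential entropy: 1 + ln(2b)

# Empirical variance from a large number of Laplace samples.
rng = np.random.default_rng(42)
samples = rng.laplace(mu, b, size=1_000_000)
var_empirical = samples.var()

print(var_closed, var_empirical)  # both ~0.02
```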
Draw 50 samples from the distribution of the restored image. Plot the first 3 samples.
samples = np.stack(tuple(restored_prob.sampling_generator(50)))
plt.figure(figsize=(16,5))
plot_some(samples[:3], pmin=0.1,pmax=99.9);
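Since each pixel has an independent Laplace distribution, drawing such samples essentially reduces to np.random.laplace applied to the per-pixel mean and scale images. A self-contained sketch with synthetic stand-in arrays (the notebook itself would use restored_prob.mean() and restored_prob.scale()):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for restored_prob.mean() and restored_prob.scale().
mean = rng.uniform(0.0, 1.0, size=(128, 128)).astype(np.float32)
scale = np.full((128, 128), 0.05, dtype=np.float32)

# Draw 50 independent samples of the whole image at once.
samples = rng.laplace(mean, scale, size=(50,) + mean.shape)

print(samples.shape)  # (50, 128, 128)

# Averaging many samples approximately recovers the mean image.
print(np.abs(samples.mean(axis=0) - mean).mean())  # small
```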
Make an animation of the 50 samples.
from matplotlib import animation
from IPython.display import HTML
fig = plt.figure(figsize=(8,8))
im = plt.imshow(samples[0], vmin=np.percentile(samples,0.1), vmax=np.percentile(samples,99.9), cmap='magma')
plt.close()
def updatefig(j):
im.set_array(samples[j])
return [im]
anim = animation.FuncAnimation(fig, updatefig, frames=len(samples), interval=100)
HTML(anim.to_jshtml())
Finally, inspect the predicted per-pixel distributions along a horizontal line profile (row i) of the restored image, comparing a credible interval to the ground truth and the input.
i = 61
line = restored_prob[i]
n = len(line)
plt.figure(figsize=(16,9))
plt.subplot(211)
plt.imshow(restored_prob.mean()[i-15:i+15], cmap='magma')
plt.plot(range(n),15*np.ones(n),'--w',linewidth=2)
plt.title('expected restored image')
plt.xlim(0,n-1); plt.axis('off')
plt.subplot(212)
q = 0.025
plt.fill_between(range(n), line.ppf(q), line.ppf(1-q), alpha=0.5, label='%.0f%% credible interval'%(100*(1-2*q)))
plt.plot(line.mean(),linewidth=3, label='expected restored image')
plt.plot(y[i],'--',linewidth=3, label='ground truth')
plt.plot(x[i],':',linewidth=1, label='input image')
plt.title('line profile')
plt.xlim(0,n-1); plt.legend(loc='lower right')
None;
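The ppf calls above invert the Laplace CDF to obtain per-pixel quantiles for the credible interval. For a Laplace distribution with location μ and scale b, the quantile function has a simple closed form, sketched here with numpy using illustrative parameters rather than model outputs:

```python
import numpy as np

def laplace_ppf(q, mu, b):
    """Quantile function (inverse CDF) of Laplace(mu, b)."""
    q = np.asarray(q, dtype=float)
    return np.where(q <= 0.5,
                    mu + b * np.log(2 * q),
                    mu - b * np.log(2 * (1 - q)))

mu, b = 0.5, 0.1  # illustrative values, not actual model outputs
q = 0.025         # same lower quantile as the 95% credible interval above
lo, hi = laplace_ppf(q, mu, b), laplace_ppf(1 - q, mu, b)

# The median equals the location parameter, and the interval is
# symmetric around it (Laplace distributions are symmetric).
print(float(laplace_ppf(0.5, mu, b)))  # 0.5
print(float(lo), float(hi))
```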