# -*- coding: utf-8 -*-
from __future__ import print_function, unicode_literals, absolute_import, division
import datetime
import warnings
import sys
import numpy as np
from six import string_types, PY2
from functools import wraps
from .config import BaseConfig
from ..utils import _raise, load_json, save_json, axes_check_and_normalize, axes_dict, move_image_axes
from ..utils.six import Path, FileNotFoundError
from ..data import Normalizer, NoNormalizer
from ..data import Resizer, NoResizer
from .pretrained import get_model_details, get_model_instance, get_registered_models
from six import add_metaclass
from abc import ABCMeta, abstractmethod, abstractproperty
def suppress_without_basedir(warn):
    """Build a decorator that skips a method call when ``self.basedir`` is ``None``.

    Parameters
    ----------
    warn : bool
        Unless this is exactly ``False``, a :mod:`warnings` warning is emitted
        every time a call is skipped.

    Returns
    -------
    callable
        A decorator for methods whose first positional argument is the
        instance. A skipped call returns ``None``; otherwise the wrapped
        method's return value is passed through unchanged.
    """
    def decorate(method):
        @wraps(method)
        def guarded(*args, **kwargs):
            instance = args[0]
            if instance.basedir is not None:
                return method(*args, **kwargs)
            # basedir=None disables all disk I/O -> skip the call entirely
            if warn is not False:
                warnings.warn("Suppressing call of '%s' (due to basedir=None)." % method.__name__)
            return None
        return guarded
    return decorate
@add_metaclass(ABCMeta)
class BaseModel(object):
    """Base model.

    Subclasses must implement :func:`_build` and :func:`_config_class`.

    Parameters
    ----------
    config : Subclass of :class:`csbdeep.models.BaseConfig` or None
        Valid configuration of a model (see :func:`BaseConfig.is_valid`).
        Will be saved to disk as JSON (``config.json``).
        If set to ``None``, will be loaded from disk (must exist).
    name : str or None
        Model name. Uses a timestamp if set to ``None`` (default).
    basedir : str
        Directory that contains (or will contain) a folder with the given model name.
        Use ``None`` to disable saving (or loading) any data to (or from) disk (regardless of other parameters).

    Raises
    ------
    FileNotFoundError
        If ``config=None`` and config cannot be loaded from disk.
    ValueError
        Illegal arguments, including invalid configuration.

    Attributes
    ----------
    config : :class:`csbdeep.models.BaseConfig`
        Configuration of the model, as provided during instantiation.
    keras_model : `Keras model <https://keras.io/getting-started/functional-api-guide/>`_
        Keras neural network model.
    name : str
        Model name.
    logdir : :class:`pathlib.Path`
        Path to model folder (which stores configuration, weights, etc.)
    """

    @classmethod
    def from_pretrained(cls, name_or_alias=None):
        """Return a registered pretrained model of this class.

        Parameters
        ----------
        name_or_alias : str or None
            Name or alias of a registered pretrained model.

        Returns
        -------
        BaseModel or None
            The model instance, or ``None`` if a ``ValueError`` was raised while
            looking it up. In that case an error message is printed to stderr
            (if a name was given), and ``get_registered_models`` is called with
            ``verbose=True`` (presumably to list the available models).
        """
        try:
            get_model_details(cls, name_or_alias, verbose=True)
            return get_model_instance(cls, name_or_alias)
        except ValueError:
            if name_or_alias is not None:
                print("Could not find model with name or alias '%s'" % (name_or_alias), file=sys.stderr)
                sys.stderr.flush()
            get_registered_models(cls, verbose=True)

    def __init__(self, config, name=None, basedir='.'):
        """See class docstring."""
        config is None or isinstance(config,self._config_class) or _raise (
            ValueError("Invalid configuration of type '%s', was expecting type '%s'." % (type(config).__name__, self._config_class.__name__))
        )
        if config is not None and not config.is_valid():
            invalid_attr = config.is_valid(True)[1]
            raise ValueError('Invalid configuration attributes: ' + ', '.join(invalid_attr))
        # without a config we must load it from disk, which is impossible without a basedir
        (not (config is None and basedir is None)) or _raise(ValueError("No config provided and cannot be loaded from disk since basedir=None."))

        name is None or (isinstance(name,string_types) and len(name)>0) or _raise(ValueError("No valid name: '%s'" % str(name)))
        basedir is None or isinstance(basedir,(string_types,Path)) or _raise(ValueError("No valid basedir: '%s'" % str(basedir)))

        self.config = config
        self.name = name if name is not None else datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S.%f")
        self.basedir = Path(basedir) if basedir is not None else None

        if config is not None:
            # config was provided -> update before it is saved to disk
            self._update_and_check_config()
        self._set_logdir()
        if config is None:
            # config was loaded from disk -> update it after loading
            self._update_and_check_config()

        # NOTE(review): not read anywhere in this class; presumably used by subclasses
        self._model_prepared = False
        self.keras_model = self._build()
        if config is None:
            # model was restored from disk -> also restore its weights
            self._find_and_load_weights()

    def __repr__(self):
        """Multi-line summary: class, name, axes, directory and config (UTF-8 bytes on Python 2)."""
        s = ("{self.__class__.__name__}({self.name}): {self.config.axes} → {self._axes_out}\n".format(self=self) +
             "├─ Directory: {}\n".format(self.logdir.resolve() if self.basedir is not None else None) +
             self._repr_extra() +
             "└─ {self.config}".format(self=self))
        return s.encode('utf-8') if PY2 else s

    def _repr_extra(self):
        """Hook for subclasses to insert extra lines into ``__repr__`` (empty here)."""
        return ""

    def _update_and_check_config(self):
        """Hook for subclasses to update/validate ``self.config`` (no-op here).

        Called from ``__init__`` before a provided config is saved to disk, or
        after a config was loaded from disk.
        """
        pass

    @suppress_without_basedir(warn=False)
    def _set_logdir(self):
        """Set ``self.logdir`` and load or save ``config.json`` there.

        If ``self.config`` is ``None``, the config is loaded from
        ``logdir/config.json`` (``FileNotFoundError`` if missing, ``ValueError``
        if invalid). Otherwise the folder is created (warning if it already
        exists) and the config is written to it. Silently skipped if
        ``basedir`` is ``None``.
        """
        self.logdir = self.basedir / self.name

        config_file = self.logdir / 'config.json'
        if self.config is None:
            if config_file.exists():
                config_dict = load_json(str(config_file))
                self.config = self._config_class(**config_dict)
                if not self.config.is_valid():
                    invalid_attr = self.config.is_valid(True)[1]
                    raise ValueError('Invalid attributes in loaded config: ' + ', '.join(invalid_attr))
            else:
                raise FileNotFoundError("config file doesn't exist: %s" % str(config_file.resolve()))
        else:
            if self.logdir.exists():
                warnings.warn('output path for model already exists, files may be overwritten: %s' % str(self.logdir.resolve()))
            self.logdir.mkdir(parents=True, exist_ok=True)
            save_json(vars(self.config), str(config_file))

    @suppress_without_basedir(warn=False)
    def _find_and_load_weights(self,prefer='best'):
        """Load a weights file (``*.h5``/``*.hdf5``) found in ``self.logdir``.

        Files whose name contains ``prefer`` are chosen first; among the
        candidates, the most recently modified file wins. Warns and returns
        without loading if no weights file exists.
        """
        from itertools import chain
        # get all weight files and sort by modification time descending (newest first)
        weights_ext = ('*.h5','*.hdf5')
        weights_files = chain(*(self.logdir.glob(ext) for ext in weights_ext))
        weights_files = reversed(sorted(weights_files, key=lambda f: f.stat().st_mtime))
        weights_files = list(weights_files)
        if len(weights_files) == 0:
            warnings.warn("Couldn't find any network weights (%s) to load." % ', '.join(weights_ext))
            return
        weights_preferred = list(filter(lambda f: prefer in f.name, weights_files))
        weights_chosen = weights_preferred[0] if len(weights_preferred)>0 else weights_files[0]
        print("Loading network weights from '%s'." % weights_chosen.name)
        self.load_weights(weights_chosen.name)

    @abstractmethod
    def _build(self):
        """ Create and return a Keras model. """

    @suppress_without_basedir(warn=True)
    def load_weights(self, name='weights_best.h5'):
        """Load neural network weights from model folder.

        Parameters
        ----------
        name : str
            Name of HDF5 weight file (as saved during or after training).
        """
        self.keras_model.load_weights(str(self.logdir/name))

    def _checkpoint_callbacks(self):
        """Return Keras ``ModelCheckpoint`` callbacks for the configured checkpoint files.

        Returns an empty list if ``basedir`` is ``None``. ``train_checkpoint``
        uses ``save_best_only=True``; ``train_checkpoint_epoch`` uses
        ``save_best_only=False``. Both save weights only.
        """
        callbacks = []
        if self.basedir is not None:
            from ..utils.tf import keras_import
            ModelCheckpoint = keras_import('callbacks', 'ModelCheckpoint')

            if self.config.train_checkpoint is not None:
                callbacks.append(ModelCheckpoint(str(self.logdir / self.config.train_checkpoint), save_best_only=True, save_weights_only=True))
            if self.config.train_checkpoint_epoch is not None:
                callbacks.append(ModelCheckpoint(str(self.logdir / self.config.train_checkpoint_epoch), save_best_only=False, save_weights_only=True))
        return callbacks

    def _training_finished(self):
        """Finalize weights on disk after training (no-op if ``basedir`` is ``None``).

        Saves the current weights as ``train_checkpoint_last``, reloads the
        ``train_checkpoint`` weights into the model, and deletes the
        ``train_checkpoint_epoch`` file (if configured and present).
        """
        if self.basedir is not None:
            if self.config.train_checkpoint_last is not None:
                self.keras_model.save_weights(str(self.logdir / self.config.train_checkpoint_last))
            if self.config.train_checkpoint is not None:
                print()
                self._find_and_load_weights(self.config.train_checkpoint)
            if self.config.train_checkpoint_epoch is not None:
                try:
                    # remove temporary weights
                    (self.logdir / self.config.train_checkpoint_epoch).unlink()
                except FileNotFoundError:
                    # file may not have been written (e.g. no epoch finished); nothing to clean up
                    pass

    @suppress_without_basedir(warn=True)
    def export_TF(self, fname=None):
        """Not implemented in the base class; subclasses may override."""
        raise NotImplementedError()

    def _make_permute_axes(self, img_axes_in, net_axes_in, net_axes_out=None, img_axes_out=None):
        """Return a function that moves data between image axes and network axes.

        The returned ``_permute_axes(data, undo=False)`` maps ``img_axes_in`` to
        ``net_axes_in`` (network input); with ``undo=True`` it maps
        ``net_axes_out`` (network output) to ``img_axes_out``. If
        ``img_axes_out`` has no channel axis, a singleton channel in the output
        is dropped. ``None`` passes through unchanged.
        """
        # img_axes_in -> net_axes_in ---NN--> net_axes_out -> img_axes_out
        if net_axes_out is None:
            net_axes_out = net_axes_in
        if img_axes_out is None:
            img_axes_out = img_axes_in
        assert 'C' in net_axes_in and 'C' in net_axes_out
        assert not 'C' in img_axes_in or 'C' in img_axes_out

        def _permute_axes(data,undo=False):
            if data is None:
                return None
            if undo:
                # decide by the *output* image axes: img_axes_out may contain 'C'
                # even when img_axes_in does not (allowed by the assert above);
                # appending another 'C' in that case would duplicate the axis
                if 'C' in img_axes_out:
                    return move_image_axes(data, net_axes_out, img_axes_out, True)
                else:
                    # output image axes have no channel axis -> append one temporarily
                    data = move_image_axes(data, net_axes_out, img_axes_out+'C', True)
                    if data.shape[-1] == 1:
                        # output is single-channel -> remove channel axis
                        data = data[...,0]
                    return data
            else:
                return move_image_axes(data, img_axes_in, net_axes_in, True)
        return _permute_axes

    def _check_normalizer_resizer(self, normalizer, resizer):
        """Substitute no-op defaults for ``None`` and validate the types.

        Returns
        -------
        tuple
            ``(normalizer, resizer)``.

        Raises
        ------
        ValueError
            If ``resizer`` is not a :class:`Resizer` or ``normalizer`` is not a
            :class:`Normalizer`.
        """
        if normalizer is None:
            normalizer = NoNormalizer()
        if resizer is None:
            resizer = NoResizer()
        isinstance(resizer,Resizer) or _raise(ValueError(
            "resizer must be an instance of '%s', not '%s'." % (Resizer.__name__, type(resizer).__name__)))
        isinstance(normalizer,Normalizer) or _raise(ValueError(
            "normalizer must be an instance of '%s', not '%s'." % (Normalizer.__name__, type(normalizer).__name__)))
        if normalizer.do_after:
            if self.config.n_channel_in != self.config.n_channel_out:
                warnings.warn('skipping normalization step after prediction because ' +
                              'number of input and output channels differ.')

        return normalizer, resizer

    @property
    def _axes_out(self):
        """Axes of the model output (same as ``config.axes`` by default)."""
        return self.config.axes

    @abstractproperty
    def _config_class(self):
        """ Class of config to be used for this model. """