Module deepposekit.models.loading

Expand source code
# -*- coding: utf-8 -*-
# Copyright 2018-2019 Jacob M. Graving <jgraving@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#    http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow.python.keras.engine import saving

import h5py
import json
import inspect

from deepposekit.models.layers.util import ImageNormalization
from deepposekit.models.layers.convolutional import (
    UpSampling2D,
    SubPixelDownscaling,
    SubPixelUpscaling,
)
from deepposekit.models.layers.deeplabcut import ImageNetPreprocess

from deepposekit.io import TrainingGenerator
from deepposekit.models.LEAP import LEAP
from deepposekit.models.StackedDenseNet import StackedDenseNet
from deepposekit.models.StackedHourglass import StackedHourglass
from deepposekit.models.DeepLabCut import DeepLabCut

# Registry of pose-model classes. ``load_model`` looks a class up here using
# the "class_name" entry of a saved file's ``pose_model_config`` attribute.
MODELS = dict(
    LEAP=LEAP,
    StackedDenseNet=StackedDenseNet,
    StackedHourglass=StackedHourglass,
    DeepLabCut=DeepLabCut,
)


# Custom layers handed to Keras as ``custom_objects`` when deserializing a
# saved model; ``load_model`` merges any user-supplied dict over these.
CUSTOM_LAYERS = dict(
    ImageNormalization=ImageNormalization,
    UpSampling2D=UpSampling2D,
    SubPixelDownscaling=SubPixelDownscaling,
    SubPixelUpscaling=SubPixelUpscaling,
    ImageNetPreprocess=ImageNetPreprocess,
)


def load_model(path, generator=None, augmenter=None, custom_objects=None):
    """
    Load a saved DeepPoseKit pose model from an HDF5 file.

    Parameters
    ----------
    path : str or os.PathLike
        Path to a ``.h5`` / ``.hdf5`` file. Besides the Keras model, the file
        must carry the ``train_generator_config`` and ``pose_model_config``
        attributes.
    generator : optional
        Data generator used to rebuild a ``TrainingGenerator`` from the saved
        configuration. When ``None``, the model is built without a training
        generator.
    augmenter : optional
        Forwarded to ``TrainingGenerator`` together with ``generator``;
        unused when ``generator`` is ``None``.
    custom_objects : dict, optional
        Extra Keras custom objects. A dict is merged over ``CUSTOM_LAYERS``
        (user entries win on name clashes). Any other truthy value is passed
        to Keras unchanged. An empty or ``None`` value uses ``CUSTOM_LAYERS``.

    Returns
    -------
    model
        Instance of the class named in the file's ``pose_model_config``. Its
        ``train_model`` is the loaded Keras model, and its input and predict
        models are re-initialized from the saved config.

    Raises
    ------
    TypeError
        If ``path`` is not a string or a path-like object resolving to one.
    ValueError
        If ``path`` does not end in ``.h5``/``.hdf5``, or if either config
        attribute is missing from the file.

    Example
    -------
    model = load_model('model.h5', augmenter=augmenter)

    """
    import os  # function-scope import keeps this block self-contained

    def _read_json_attr(h5file, name, missing_message):
        # Return the JSON document stored in HDF5 attribute ``name``.
        raw = h5file.attrs.get(name)
        if raw is None:
            raise ValueError(missing_message)
        # h5py >= 3 returns string attributes as str; older releases return
        # bytes. Decode only when needed so both work.
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        return json.loads(raw)

    def _init_kwargs(cls, config, exclude):
        # Pull cls.__init__'s named arguments out of ``config``, skipping
        # ``self``, the names in ``exclude`` and any *args/**kwargs catch-alls.
        parameters = inspect.signature(cls.__init__).parameters.values()
        return {
            param.name: config[param.name]
            for param in parameters
            if param.name != "self"
            and param.name not in exclude
            and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        }

    if not custom_objects:
        custom_objects = CUSTOM_LAYERS
    elif isinstance(custom_objects, dict):
        custom_objects = {**CUSTOM_LAYERS, **custom_objects}

    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    if not isinstance(path, str):
        raise TypeError("file must be type `str`")
    if not path.endswith((".h5", ".hdf5")):
        raise ValueError("file must be .h5 file")
    filepath = path

    # Read the DeepPoseKit metadata before the full Keras load so a file
    # without it is rejected without deserializing the whole network.
    with h5py.File(filepath, "r") as h5file:
        train_generator_config = _read_json_attr(
            h5file, "train_generator_config", "No data generator found in config file"
        )["config"]
        pose_model_config = _read_json_attr(
            h5file, "pose_model_config", "No pose model found in config file"
        )
    model_name = pose_model_config["class_name"]
    model_config = pose_model_config["config"]

    train_model = saving.load_model(filepath, custom_objects=custom_objects)

    if generator is not None:
        generator_kwargs = _init_kwargs(
            TrainingGenerator, train_generator_config, ("augmenter", "generator")
        )
        train_generator = TrainingGenerator(
            generator=generator, augmenter=augmenter, **generator_kwargs
        )
    else:
        train_generator = None

    model_class = MODELS[model_name]
    model_kwargs = _init_kwargs(model_class, model_config, ("train_generator",))
    model_kwargs["train_generator"] = train_generator
    # Ask the constructor to skip its own initialization; the train, input
    # and predict models are initialized manually below from the loaded file.
    model_kwargs["skip_init"] = True

    model = model_class(**model_kwargs)
    model.train_model = train_model
    model.__init_train_model__()
    model.__init_input__(model_config["image_shape"])
    model.__init_predict_model__(
        output_shape=model_config["output_shape"],
        keypoints_shape=model_config["keypoints_shape"],
        downsample_factor=model_config["downsample_factor"],
        output_sigma=model_config["output_sigma"],
    )

    return model

Functions

def load_model(path, generator=None, augmenter=None, custom_objects=None)

Load a saved DeepPoseKit pose model from an HDF5 (.h5 / .hdf5) file.

Example

model = load_model('model.h5', augmenter=augmenter)

Expand source code
def load_model(path, generator=None, augmenter=None, custom_objects=None):
    """
    Load a saved DeepPoseKit pose model from an HDF5 file.

    Parameters
    ----------
    path : str or os.PathLike
        Path to a ``.h5`` / ``.hdf5`` file. Besides the Keras model, the file
        must carry the ``train_generator_config`` and ``pose_model_config``
        attributes.
    generator : optional
        Data generator used to rebuild a ``TrainingGenerator`` from the saved
        configuration. When ``None``, the model is built without a training
        generator.
    augmenter : optional
        Forwarded to ``TrainingGenerator`` together with ``generator``;
        unused when ``generator`` is ``None``.
    custom_objects : dict, optional
        Extra Keras custom objects. A dict is merged over ``CUSTOM_LAYERS``
        (user entries win on name clashes). Any other truthy value is passed
        to Keras unchanged. An empty or ``None`` value uses ``CUSTOM_LAYERS``.

    Returns
    -------
    model
        Instance of the class named in the file's ``pose_model_config``. Its
        ``train_model`` is the loaded Keras model, and its input and predict
        models are re-initialized from the saved config.

    Raises
    ------
    TypeError
        If ``path`` is not a string or a path-like object resolving to one.
    ValueError
        If ``path`` does not end in ``.h5``/``.hdf5``, or if either config
        attribute is missing from the file.

    Example
    -------
    model = load_model('model.h5', augmenter=augmenter)

    """
    import os  # function-scope import keeps this block self-contained

    def _read_json_attr(h5file, name, missing_message):
        # Return the JSON document stored in HDF5 attribute ``name``.
        raw = h5file.attrs.get(name)
        if raw is None:
            raise ValueError(missing_message)
        # h5py >= 3 returns string attributes as str; older releases return
        # bytes. Decode only when needed so both work.
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        return json.loads(raw)

    def _init_kwargs(cls, config, exclude):
        # Pull cls.__init__'s named arguments out of ``config``, skipping
        # ``self``, the names in ``exclude`` and any *args/**kwargs catch-alls.
        parameters = inspect.signature(cls.__init__).parameters.values()
        return {
            param.name: config[param.name]
            for param in parameters
            if param.name != "self"
            and param.name not in exclude
            and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        }

    if not custom_objects:
        custom_objects = CUSTOM_LAYERS
    elif isinstance(custom_objects, dict):
        custom_objects = {**CUSTOM_LAYERS, **custom_objects}

    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    if not isinstance(path, str):
        raise TypeError("file must be type `str`")
    if not path.endswith((".h5", ".hdf5")):
        raise ValueError("file must be .h5 file")
    filepath = path

    # Read the DeepPoseKit metadata before the full Keras load so a file
    # without it is rejected without deserializing the whole network.
    with h5py.File(filepath, "r") as h5file:
        train_generator_config = _read_json_attr(
            h5file, "train_generator_config", "No data generator found in config file"
        )["config"]
        pose_model_config = _read_json_attr(
            h5file, "pose_model_config", "No pose model found in config file"
        )
    model_name = pose_model_config["class_name"]
    model_config = pose_model_config["config"]

    train_model = saving.load_model(filepath, custom_objects=custom_objects)

    if generator is not None:
        generator_kwargs = _init_kwargs(
            TrainingGenerator, train_generator_config, ("augmenter", "generator")
        )
        train_generator = TrainingGenerator(
            generator=generator, augmenter=augmenter, **generator_kwargs
        )
    else:
        train_generator = None

    model_class = MODELS[model_name]
    model_kwargs = _init_kwargs(model_class, model_config, ("train_generator",))
    model_kwargs["train_generator"] = train_generator
    # Ask the constructor to skip its own initialization; the train, input
    # and predict models are initialized manually below from the loaded file.
    model_kwargs["skip_init"] = True

    model = model_class(**model_kwargs)
    model.train_model = train_model
    model.__init_train_model__()
    model.__init_input__(model_config["image_shape"])
    model.__init_predict_model__(
        output_shape=model_config["output_shape"],
        keypoints_shape=model_config["keypoints_shape"],
        downsample_factor=model_config["downsample_factor"],
        output_sigma=model_config["output_sigma"],
    )

    return model