Module deepposekit.models.saving

Expand source code
# -*- coding: utf-8 -*-
# Copyright 2018-2019 Jacob M. Graving <jgraving@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#    http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow.python.keras.engine import saving
import h5py
import json
from deepposekit.utils.io import get_json_type


def save_model(model, path, optimizer=True):

    if isinstance(path, str):
        if path.endswith(".h5") or path.endswith(".hdf5"):
            filepath = path
        else:
            raise ValueError("file must be .h5 file")
    else:
        raise TypeError("file must be type `str`")

    saving.save_model(model.train_model, path, include_optimizer=optimizer)

    with h5py.File(filepath, "r+") as h5file:

        train_generator = model.train_generator

        h5file.attrs["train_generator_config"] = json.dumps(
            {
                "class_name": train_generator.__class__.__name__,
                "config": train_generator.get_config(),
            },
            default=get_json_type,
        ).encode("utf8")

        h5file.attrs["pose_model_config"] = json.dumps(
            {"class_name": model.__class__.__name__, "config": model.get_config()},
            default=get_json_type,
        ).encode("utf8")

Functions

def save_model(model, path, optimizer=True)
Expand source code
def save_model(model, path, optimizer=True):

    if isinstance(path, str):
        if path.endswith(".h5") or path.endswith(".hdf5"):
            filepath = path
        else:
            raise ValueError("file must be .h5 file")
    else:
        raise TypeError("file must be type `str`")

    saving.save_model(model.train_model, path, include_optimizer=optimizer)

    with h5py.File(filepath, "r+") as h5file:

        train_generator = model.train_generator

        h5file.attrs["train_generator_config"] = json.dumps(
            {
                "class_name": train_generator.__class__.__name__,
                "config": train_generator.get_config(),
            },
            default=get_json_type,
        ).encode("utf8")

        h5file.attrs["pose_model_config"] = json.dumps(
            {"class_name": model.__class__.__name__, "config": model.get_config()},
            default=get_json_type,
        ).encode("utf8")