Module deepposekit.io.DLCDataGenerator

Expand source code
# -*- coding: utf-8 -*-
# Copyright 2018-2019 Jacob M. Graving <jgraving@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#    http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas as pd
import os
import cv2
import yaml
import glob

from deepposekit.io.BaseGenerator import BaseGenerator

# Public API of this module: only the generator class is exported.
__all__ = ["DLCDataGenerator"]


class DLCDataGenerator(BaseGenerator):
    """
    Creates a data generator for accessing a DeepLabCut annotation set.

    Annotation tables are found by globbing for ``*.h5`` files exactly two
    directory levels below ``project_path``. In a standard DeepLabCut layout
    this is presumably ``labeled-data/<video>/CollectedData_<scorer>.h5``.
    The tables are read with ``pandas.read_hdf`` and concatenated in sorted
    path order, so sample indexes are stable across runs and platforms.
    Keypoints are looked up by ``(scorer, bodypart, "x"/"y")`` column tuples,
    so the tables are expected to use DeepLabCut's three-level column index.

    Parameters
    ----------
    project_path : str
        Path to the project with config.yaml and images.
        e.g. '/path/to/project/' (a trailing separator is optional).
    **kwargs
        Forwarded unchanged to ``BaseGenerator.__init__``.

    Raises
    ------
    ValueError
        If no ``*.h5`` annotation file is found under ``project_path``.
    """

    def __init__(self, project_path, **kwargs):
        self.project_path = project_path
        # Without ``recursive=True``, glob treats "**" exactly like "*", so
        # this matches .h5 files exactly two directory levels deep. glob
        # returns files in arbitrary filesystem order; sort so that row order
        # (and therefore every sample index) is deterministic.
        self.annotations_path = sorted(
            glob.glob(os.path.join(self.project_path, "*", "*", "*.h5"))
        )
        if not self.annotations_path:
            # pd.concat([]) would raise an opaque "No objects to concatenate".
            raise ValueError(
                "no `*.h5` annotation files found two directory levels "
                "below `{}`".format(self.project_path)
            )
        annotations = [pd.read_hdf(datapath) for datapath in self.annotations_path]
        self.annotations = pd.concat(annotations)

        config_path = os.path.join(self.project_path, "config.yaml")
        with open(config_path, "r") as config_file:
            # SafeLoader: only plain YAML data, no arbitrary Python objects.
            self.dlcconfig = yaml.load(config_file, Loader=yaml.SafeLoader)

        self.bodyparts = self.dlcconfig["bodyparts"]
        self.n_keypoints = len(self.bodyparts)
        self.scorer = self.dlcconfig["scorer"]

        self.n_samples = self.annotations.shape[0]
        self.index = np.arange(self.n_samples)

        # Called last: the base initializer presumably uses the shape hooks
        # below (its docs mention sizing graph/swap_index from the keypoints
        # shape), which need the attributes set above.
        super(DLCDataGenerator, self).__init__(**kwargs)

    def _image_path(self, image_name):
        """Return the on-disk path for an annotation row label.

        A plain string label is treated as a path relative to
        ``project_path``. A tuple label (as produced by a row MultiIndex,
        which newer DeepLabCut versions may use — TODO confirm) is joined
        component-wise.
        """
        if isinstance(image_name, tuple):
            return os.path.join(self.project_path, *image_name)
        return os.path.join(self.project_path, image_name)

    def compute_image_shape(self):
        """Return the shape of a single image, inferred from sample 0."""
        return self.get_images([0]).shape[1:]

    def compute_keypoints_shape(self):
        """Return ``(n_keypoints, 2)``: one (x, y) pair per body part."""
        return (self.n_keypoints, 2)

    def get_images(self, indexes):
        """Load the images for the given sample indexes.

        Parameters
        ----------
        indexes : array-like of int or slice
            Positions into ``self.index``.

        Returns
        -------
        numpy.ndarray
            The images stacked along a new first axis, as returned by
            ``cv2.imread`` (OpenCV's default BGR channel order).

        Raises
        ------
        IndexError
            If an image file does not exist.
        IOError
            If an existing image file cannot be decoded.
        """
        indexes = self.index[indexes]
        images = []
        for idx in indexes:
            # The index label of the row is the image's (relative) path.
            image_name = self.annotations.index[idx]
            filepath = self._image_path(image_name)
            if not os.path.exists(filepath):
                raise IndexError("image `{}` does not exist".format(image_name))
            image = cv2.imread(filepath)
            if image is None:
                # cv2.imread reports unreadable files by returning None
                # instead of raising; fail loudly rather than letting
                # np.stack build an object array or raise a shape error.
                raise IOError("image `{}` could not be read".format(filepath))
            images.append(image)
        return np.stack(images)

    def get_keypoints(self, indexes):
        """Return the annotated keypoints for the given sample indexes.

        Parameters
        ----------
        indexes : array-like of int or slice
            Positions into ``self.index``.

        Returns
        -------
        numpy.ndarray
            Array of shape ``(len(indexes), n_keypoints, 2)``. The last axis
            holds ``(x, y)``, with body parts ordered as in config.yaml.
        """
        indexes = self.index[indexes]
        keypoints = []
        for idx in indexes:
            row = self.annotations.iloc[idx]
            coords = [
                [row[(self.scorer, part, "x")], row[(self.scorer, part, "y")]]
                for part in self.bodyparts
            ]
            keypoints.append(np.array(coords))
        return np.stack(keypoints)

    def __len__(self):
        """Return the total number of annotated samples."""
        return self.n_samples

    def get_config(self):
        """Return the generator config.

        Base-class entries override ``project_path`` on a key clash.
        """
        config = {"project_path": self.project_path}
        config.update(super(DLCDataGenerator, self).get_config())
        return config


if __name__ == "__main__":
    # Smoke test: construct a generator over the bundled DeepLabCut example.
    example_project = "./deeplabcut/examples/openfield-Pranav-2018-10-30/"
    data_generator = DLCDataGenerator(project_path=example_project)

Classes

class DLCDataGenerator (project_path, **kwargs)

Creates a data generator for accessing a DeepLabCut annotation set.

Parameters

project_path : str
Path to the DeepLabCut project directory containing config.yaml and the labeled images, e.g. '/path/to/project/'.

Initializes the BaseGenerator class. If `graph` and `swap_index` are not defined, each is set to a vector of -1 corresponding to the keypoints shape.

Expand source code
class DLCDataGenerator(BaseGenerator):
    """
    Creates a data generator for accessing a DeepLabCut annotation set.

    Annotation tables are found by globbing for ``*.h5`` files exactly two
    directory levels below ``project_path``. In a standard DeepLabCut layout
    this is presumably ``labeled-data/<video>/CollectedData_<scorer>.h5``.
    The tables are read with ``pandas.read_hdf`` and concatenated in sorted
    path order, so sample indexes are stable across runs and platforms.
    Keypoints are looked up by ``(scorer, bodypart, "x"/"y")`` column tuples,
    so the tables are expected to use DeepLabCut's three-level column index.

    Parameters
    ----------
    project_path : str
        Path to the project with config.yaml and images.
        e.g. '/path/to/project/' (a trailing separator is optional).
    **kwargs
        Forwarded unchanged to ``BaseGenerator.__init__``.

    Raises
    ------
    ValueError
        If no ``*.h5`` annotation file is found under ``project_path``.
    """

    def __init__(self, project_path, **kwargs):
        self.project_path = project_path
        # Without ``recursive=True``, glob treats "**" exactly like "*", so
        # this matches .h5 files exactly two directory levels deep. glob
        # returns files in arbitrary filesystem order; sort so that row order
        # (and therefore every sample index) is deterministic.
        self.annotations_path = sorted(
            glob.glob(os.path.join(self.project_path, "*", "*", "*.h5"))
        )
        if not self.annotations_path:
            # pd.concat([]) would raise an opaque "No objects to concatenate".
            raise ValueError(
                "no `*.h5` annotation files found two directory levels "
                "below `{}`".format(self.project_path)
            )
        annotations = [pd.read_hdf(datapath) for datapath in self.annotations_path]
        self.annotations = pd.concat(annotations)

        config_path = os.path.join(self.project_path, "config.yaml")
        with open(config_path, "r") as config_file:
            # SafeLoader: only plain YAML data, no arbitrary Python objects.
            self.dlcconfig = yaml.load(config_file, Loader=yaml.SafeLoader)

        self.bodyparts = self.dlcconfig["bodyparts"]
        self.n_keypoints = len(self.bodyparts)
        self.scorer = self.dlcconfig["scorer"]

        self.n_samples = self.annotations.shape[0]
        self.index = np.arange(self.n_samples)

        # Called last: the base initializer presumably uses the shape hooks
        # below (its docs mention sizing graph/swap_index from the keypoints
        # shape), which need the attributes set above.
        super(DLCDataGenerator, self).__init__(**kwargs)

    def _image_path(self, image_name):
        """Return the on-disk path for an annotation row label.

        A plain string label is treated as a path relative to
        ``project_path``. A tuple label (as produced by a row MultiIndex,
        which newer DeepLabCut versions may use — TODO confirm) is joined
        component-wise.
        """
        if isinstance(image_name, tuple):
            return os.path.join(self.project_path, *image_name)
        return os.path.join(self.project_path, image_name)

    def compute_image_shape(self):
        """Return the shape of a single image, inferred from sample 0."""
        return self.get_images([0]).shape[1:]

    def compute_keypoints_shape(self):
        """Return ``(n_keypoints, 2)``: one (x, y) pair per body part."""
        return (self.n_keypoints, 2)

    def get_images(self, indexes):
        """Load the images for the given sample indexes.

        Parameters
        ----------
        indexes : array-like of int or slice
            Positions into ``self.index``.

        Returns
        -------
        numpy.ndarray
            The images stacked along a new first axis, as returned by
            ``cv2.imread`` (OpenCV's default BGR channel order).

        Raises
        ------
        IndexError
            If an image file does not exist.
        IOError
            If an existing image file cannot be decoded.
        """
        indexes = self.index[indexes]
        images = []
        for idx in indexes:
            # The index label of the row is the image's (relative) path.
            image_name = self.annotations.index[idx]
            filepath = self._image_path(image_name)
            if not os.path.exists(filepath):
                raise IndexError("image `{}` does not exist".format(image_name))
            image = cv2.imread(filepath)
            if image is None:
                # cv2.imread reports unreadable files by returning None
                # instead of raising; fail loudly rather than letting
                # np.stack build an object array or raise a shape error.
                raise IOError("image `{}` could not be read".format(filepath))
            images.append(image)
        return np.stack(images)

    def get_keypoints(self, indexes):
        """Return the annotated keypoints for the given sample indexes.

        Parameters
        ----------
        indexes : array-like of int or slice
            Positions into ``self.index``.

        Returns
        -------
        numpy.ndarray
            Array of shape ``(len(indexes), n_keypoints, 2)``. The last axis
            holds ``(x, y)``, with body parts ordered as in config.yaml.
        """
        indexes = self.index[indexes]
        keypoints = []
        for idx in indexes:
            row = self.annotations.iloc[idx]
            coords = [
                [row[(self.scorer, part, "x")], row[(self.scorer, part, "y")]]
                for part in self.bodyparts
            ]
            keypoints.append(np.array(coords))
        return np.stack(keypoints)

    def __len__(self):
        """Return the total number of annotated samples."""
        return self.n_samples

    def get_config(self):
        """Return the generator config.

        Base-class entries override ``project_path`` on a key clash.
        """
        config = {"project_path": self.project_path}
        config.update(super(DLCDataGenerator, self).get_config())
        return config

Ancestors

  • BaseGenerator
  • tensorflow.python.keras.utils.data_utils.Sequence

Methods

def get_config(self)
Expand source code
def get_config(self):
    """Return this generator's config merged with the base-class config.

    ``project_path`` is listed first; any key also present in the base
    config takes the base config's value.
    """
    merged = {"project_path": self.project_path}
    merged.update(super(DLCDataGenerator, self).get_config())
    return merged

Inherited members