Module deepposekit.io.DataGenerator
Expand source code
# -*- coding: utf-8 -*-
# Copyright 2018-2019 Jacob M. Graving <jgraving@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tensorflow.keras.utils import Sequence
import h5py
import numpy as np
import os
import copy
from deepposekit.io.BaseGenerator import BaseGenerator
__all__ = ["DataGenerator"]
class DataGenerator(BaseGenerator):
    """
    Creates a data generator for accessing an annotation set.

    The annotations file is opened on demand for every read/write and is
    never kept open between calls.

    Parameters
    ----------
    datapath : str
        The path to the annotations file. Must be .h5
        e.g. '/path/to/file.h5'
    dataset : str
        The key for the image dataset in the annotations file.
        e.g. 'images'
    mode : str
        The mode for loading and saving data.
        Must be 'unannotated', 'annotated', or 'full'
    **kwargs
        Forwarded unchanged to ``BaseGenerator.__init__``.

    Raises
    ------
    TypeError
        If `datapath` or `dataset` is not a `str`.
    ValueError
        If `datapath` is not an existing .h5 file, if `mode` is not one of
        the accepted values, or if `mode` is 'annotated' and the file
        contains no annotated samples.
    KeyError
        If the file lacks the 'annotations', 'annotated' or 'skeleton'
        entries, or the requested image dataset.
    """

    # Accepted values for `mode`, shared by __init__ and __call__.
    _MODES = ("full", "annotated", "unannotated")

    def __init__(self, datapath, dataset="images", mode="annotated", **kwargs):
        # Check annotations file
        if not isinstance(datapath, str):
            raise TypeError("datapath must be type `str`")
        if not datapath.endswith(".h5"):
            raise ValueError("datapath must be .h5 file")
        if not os.path.exists(datapath):
            raise ValueError("datapath file or directory does not exist")
        self.datapath = datapath

        if not isinstance(dataset, str):
            raise TypeError("dataset must be type `str`")
        self.dataset = dataset

        with h5py.File(self.datapath, mode="r") as h5file:
            # Check for the required entries
            if "annotations" not in h5file:
                raise KeyError("annotations not found in annotations file")
            if "annotated" not in h5file:
                raise KeyError("annotated not found in annotations file")
            if "skeleton" not in h5file:
                raise KeyError("skeleton not found in annotations file")
            if self.dataset not in h5file:
                raise KeyError("image dataset not found in annotations file")

            # Get annotations attributes
            if mode not in self._MODES:
                raise ValueError(
                    "mode must be 'full', 'annotated', or 'unannotated'"
                )
            self.mode = mode

            # `Dataset.value` was removed in h5py 3; `[()]` reads the whole
            # dataset. A sample counts as annotated only when every entry
            # along axis 1 of the "annotated" dataset is truthy.
            self.annotated = np.all(h5file["annotated"][()], axis=1)
            self.annotated_index = np.where(self.annotated)[0]
            self.n_annotated = self.annotated_index.shape[0]
            if self.n_annotated == 0 and self.mode == "annotated":
                raise ValueError("The number of annotated images is zero")

            self.n_keypoints = h5file["annotations"].shape[1]
            self.n_samples = h5file[self.dataset].shape[0]
            self.index = np.arange(self.n_samples)
            self.unannotated_index = np.where(~self.annotated)[0]
            self.n_unannotated = self.unannotated_index.shape[0]

            # Initialize skeleton attributes (columns 0 and 1 of "skeleton")
            skeleton = h5file["skeleton"]
            self.graph = skeleton[:, 0]
            self.swap_index = skeleton[:, 1]

        # Everything read above is already in memory, so the base class is
        # initialized after the file handle has been released.
        super(DataGenerator, self).__init__(**kwargs)

    def compute_keypoints_shape(self):
        """Return the per-sample shape of the 'annotations' dataset."""
        with h5py.File(self.datapath, mode="r") as h5file:
            return h5file["annotations"].shape[1:]

    def compute_image_shape(self):
        """Return the per-sample shape of the image dataset."""
        with h5py.File(self.datapath, mode="r") as h5file:
            return h5file[self.dataset].shape[1:]

    def get_indexes(self, indexes):
        """Translate mode-relative positions into row indexes of the file.

        `indexes` is used to index the lookup table selected by the current
        mode (`annotated_index`, `unannotated_index`, or `index`).
        """
        # Compare by value: identity (`is`) only works for interned strings.
        if self.mode == "annotated":
            return self.annotated_index[indexes]
        if self.mode == "unannotated":
            return self.unannotated_index[indexes]
        return self.index[indexes]

    def get_images(self, indexes):
        """Read the images at the given mode-relative positions.

        Returns the per-sample arrays stacked along a new first axis.
        """
        indexes = self.get_indexes(indexes)
        with h5py.File(self.datapath, mode="r") as h5file:
            images_dataset = h5file[self.dataset]  # look up the dataset once
            images = [images_dataset[idx] for idx in indexes]
        return np.stack(images)

    def get_keypoints(self, indexes):
        """Read the annotations at the given mode-relative positions.

        Returns the per-sample arrays stacked along a new first axis.
        """
        indexes = self.get_indexes(indexes)
        with h5py.File(self.datapath, mode="r") as h5file:
            annotations = h5file["annotations"]  # look up the dataset once
            keypoints = [annotations[idx] for idx in indexes]
        return np.stack(keypoints)

    def set_keypoints(self, indexes, keypoints):
        """Write keypoints back to the 'annotations' dataset.

        Parameters
        ----------
        indexes
            Mode-relative positions of the samples to overwrite.
        keypoints
            Array-like with a `shape` whose last axis has size 2 or 3.
            When it is 3, only the first two channels are stored.

        Raises
        ------
        ValueError
            If the last axis of `keypoints` is neither 2 nor 3.
        """
        n_channels = keypoints.shape[-1]
        if n_channels == 3:
            keypoints = keypoints[..., :2]
        elif n_channels != 2:
            raise ValueError("data shape does not match annotations")

        indexes = self.get_indexes(indexes)
        with h5py.File(self.datapath, mode="r+") as h5file:
            annotations = h5file["annotations"]
            for idx, keypoints_idx in zip(indexes, keypoints):
                annotations[idx] = keypoints_idx

    def __call__(self, mode="annotated"):
        """Switch to `mode` and return a deep copy of the generator.

        Note that the mode is set on this instance *before* copying, so both
        the original and the returned copy end up in the requested mode.

        Raises
        ------
        ValueError
            If `mode` is not accepted, or if the requested subset
            ('annotated' / 'unannotated') is empty.
        """
        if mode not in self._MODES:
            raise ValueError("mode must be full, annotated, or unannotated")
        if mode == "annotated" and self.n_annotated == 0:
            raise ValueError(
                "cannot return annotated samples, "
                "number of annotated samples is zero"
            )
        if mode == "unannotated" and self.n_unannotated == 0:
            raise ValueError(
                "cannot return unannotated samples, "
                "number of unannotated samples is zero"
            )
        self.mode = mode
        return copy.deepcopy(self)

    def __len__(self):
        """Return the number of samples visible in the current mode."""
        if self.mode == "annotated":
            return self.n_annotated
        if self.mode == "unannotated":
            return self.n_unannotated
        return self.n_samples

    def get_config(self):
        """Return this generator's settings merged with the base config.

        On a key collision the value from the base class wins.
        """
        config = {"datapath": self.datapath, "dataset": self.dataset}
        config.update(super(DataGenerator, self).get_config())
        return config
Classes
class DataGenerator (datapath, dataset='images', mode='annotated', **kwargs)
-
Creates a data generator for accessing an annotation set.
Parameters
datapath
:str
- The path to the annotations file. Must be .h5 e.g. '/path/to/file.h5'
dataset
:str
- The key for the image dataset in the annotations file. e.g. 'images'
mode
:str
- The mode for loading and saving data. Must be 'unannotated', 'annotated', or 'full'.
Initializes the BaseGenerator class. If graph and swap_index are not defined, they are set to a vector of -1 corresponding to the keypoints shape.
Expand source code
class DataGenerator(BaseGenerator):
    """
    Creates a data generator for accessing an annotation set.

    The annotations file is opened on demand for every read/write and is
    never kept open between calls.

    Parameters
    ----------
    datapath : str
        The path to the annotations file. Must be .h5
        e.g. '/path/to/file.h5'
    dataset : str
        The key for the image dataset in the annotations file.
        e.g. 'images'
    mode : str
        The mode for loading and saving data.
        Must be 'unannotated', 'annotated', or 'full'
    **kwargs
        Forwarded unchanged to ``BaseGenerator.__init__``.

    Raises
    ------
    TypeError
        If `datapath` or `dataset` is not a `str`.
    ValueError
        If `datapath` is not an existing .h5 file, if `mode` is not one of
        the accepted values, or if `mode` is 'annotated' and the file
        contains no annotated samples.
    KeyError
        If the file lacks the 'annotations', 'annotated' or 'skeleton'
        entries, or the requested image dataset.
    """

    # Accepted values for `mode`, shared by __init__ and __call__.
    _MODES = ("full", "annotated", "unannotated")

    def __init__(self, datapath, dataset="images", mode="annotated", **kwargs):
        # Check annotations file
        if not isinstance(datapath, str):
            raise TypeError("datapath must be type `str`")
        if not datapath.endswith(".h5"):
            raise ValueError("datapath must be .h5 file")
        if not os.path.exists(datapath):
            raise ValueError("datapath file or directory does not exist")
        self.datapath = datapath

        if not isinstance(dataset, str):
            raise TypeError("dataset must be type `str`")
        self.dataset = dataset

        with h5py.File(self.datapath, mode="r") as h5file:
            # Check for the required entries
            if "annotations" not in h5file:
                raise KeyError("annotations not found in annotations file")
            if "annotated" not in h5file:
                raise KeyError("annotated not found in annotations file")
            if "skeleton" not in h5file:
                raise KeyError("skeleton not found in annotations file")
            if self.dataset not in h5file:
                raise KeyError("image dataset not found in annotations file")

            # Get annotations attributes
            if mode not in self._MODES:
                raise ValueError(
                    "mode must be 'full', 'annotated', or 'unannotated'"
                )
            self.mode = mode

            # `Dataset.value` was removed in h5py 3; `[()]` reads the whole
            # dataset. A sample counts as annotated only when every entry
            # along axis 1 of the "annotated" dataset is truthy.
            self.annotated = np.all(h5file["annotated"][()], axis=1)
            self.annotated_index = np.where(self.annotated)[0]
            self.n_annotated = self.annotated_index.shape[0]
            if self.n_annotated == 0 and self.mode == "annotated":
                raise ValueError("The number of annotated images is zero")

            self.n_keypoints = h5file["annotations"].shape[1]
            self.n_samples = h5file[self.dataset].shape[0]
            self.index = np.arange(self.n_samples)
            self.unannotated_index = np.where(~self.annotated)[0]
            self.n_unannotated = self.unannotated_index.shape[0]

            # Initialize skeleton attributes (columns 0 and 1 of "skeleton")
            skeleton = h5file["skeleton"]
            self.graph = skeleton[:, 0]
            self.swap_index = skeleton[:, 1]

        # Everything read above is already in memory, so the base class is
        # initialized after the file handle has been released.
        super(DataGenerator, self).__init__(**kwargs)

    def compute_keypoints_shape(self):
        """Return the per-sample shape of the 'annotations' dataset."""
        with h5py.File(self.datapath, mode="r") as h5file:
            return h5file["annotations"].shape[1:]

    def compute_image_shape(self):
        """Return the per-sample shape of the image dataset."""
        with h5py.File(self.datapath, mode="r") as h5file:
            return h5file[self.dataset].shape[1:]

    def get_indexes(self, indexes):
        """Translate mode-relative positions into row indexes of the file.

        `indexes` is used to index the lookup table selected by the current
        mode (`annotated_index`, `unannotated_index`, or `index`).
        """
        # Compare by value: identity (`is`) only works for interned strings.
        if self.mode == "annotated":
            return self.annotated_index[indexes]
        if self.mode == "unannotated":
            return self.unannotated_index[indexes]
        return self.index[indexes]

    def get_images(self, indexes):
        """Read the images at the given mode-relative positions.

        Returns the per-sample arrays stacked along a new first axis.
        """
        indexes = self.get_indexes(indexes)
        with h5py.File(self.datapath, mode="r") as h5file:
            images_dataset = h5file[self.dataset]  # look up the dataset once
            images = [images_dataset[idx] for idx in indexes]
        return np.stack(images)

    def get_keypoints(self, indexes):
        """Read the annotations at the given mode-relative positions.

        Returns the per-sample arrays stacked along a new first axis.
        """
        indexes = self.get_indexes(indexes)
        with h5py.File(self.datapath, mode="r") as h5file:
            annotations = h5file["annotations"]  # look up the dataset once
            keypoints = [annotations[idx] for idx in indexes]
        return np.stack(keypoints)

    def set_keypoints(self, indexes, keypoints):
        """Write keypoints back to the 'annotations' dataset.

        Parameters
        ----------
        indexes
            Mode-relative positions of the samples to overwrite.
        keypoints
            Array-like with a `shape` whose last axis has size 2 or 3.
            When it is 3, only the first two channels are stored.

        Raises
        ------
        ValueError
            If the last axis of `keypoints` is neither 2 nor 3.
        """
        n_channels = keypoints.shape[-1]
        if n_channels == 3:
            keypoints = keypoints[..., :2]
        elif n_channels != 2:
            raise ValueError("data shape does not match annotations")

        indexes = self.get_indexes(indexes)
        with h5py.File(self.datapath, mode="r+") as h5file:
            annotations = h5file["annotations"]
            for idx, keypoints_idx in zip(indexes, keypoints):
                annotations[idx] = keypoints_idx

    def __call__(self, mode="annotated"):
        """Switch to `mode` and return a deep copy of the generator.

        Note that the mode is set on this instance *before* copying, so both
        the original and the returned copy end up in the requested mode.

        Raises
        ------
        ValueError
            If `mode` is not accepted, or if the requested subset
            ('annotated' / 'unannotated') is empty.
        """
        if mode not in self._MODES:
            raise ValueError("mode must be full, annotated, or unannotated")
        if mode == "annotated" and self.n_annotated == 0:
            raise ValueError(
                "cannot return annotated samples, "
                "number of annotated samples is zero"
            )
        if mode == "unannotated" and self.n_unannotated == 0:
            raise ValueError(
                "cannot return unannotated samples, "
                "number of unannotated samples is zero"
            )
        self.mode = mode
        return copy.deepcopy(self)

    def __len__(self):
        """Return the number of samples visible in the current mode."""
        if self.mode == "annotated":
            return self.n_annotated
        if self.mode == "unannotated":
            return self.n_unannotated
        return self.n_samples

    def get_config(self):
        """Return this generator's settings merged with the base config.

        On a key collision the value from the base class wins.
        """
        config = {"datapath": self.datapath, "dataset": self.dataset}
        config.update(super(DataGenerator, self).get_config())
        return config
Ancestors
- BaseGenerator
- tensorflow.python.keras.utils.data_utils.Sequence
Methods
def get_config(self)
-
Expand source code
def get_config(self):
    """Return this generator's settings merged with the base config.

    Starts from the `datapath` and `dataset` entries and overlays whatever
    the parent class reports; on a key collision the parent's value wins.
    """
    merged = {"datapath": self.datapath, "dataset": self.dataset}
    merged.update(super(DataGenerator, self).get_config())
    return merged
def get_indexes(self, indexes)
-
Expand source code
def get_indexes(self, indexes):
    """Translate mode-relative positions into row indexes of the file.

    `indexes` is used to index the lookup table selected by the current
    mode: `annotated_index` for 'annotated', `unannotated_index` for
    'unannotated', and `index` for anything else.
    """
    # Compare by value. The previous identity test (`is`) only matched when
    # the mode string happened to be the same interned object as the literal,
    # so a mode built at runtime silently fell through to the full index.
    if self.mode == "annotated":
        return self.annotated_index[indexes]
    if self.mode == "unannotated":
        return self.unannotated_index[indexes]
    return self.index[indexes]
Inherited members