Module deepposekit.io.BaseGenerator
Expand source code
# -*- coding: utf-8 -*-
# Copyright 2018-2019 Jacob M. Graving <jgraving@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tensorflow.keras.utils import Sequence
import numpy as np
__all__ = ["BaseGenerator"]
class BaseGenerator(Sequence):
"""
BaseGenerator class for abstracting data loading and saving.
Attributes that should be defined before use:
__init__
__len__
compute_image_shape
compute_keypoints_shape
get_images
get_keypoints
set_keypoints (only needed for saving data)
See docstrings for further details.
"""
def __init__(self, **kwargs):
"""
Initializes the BaseGenerator class.
If graph and swap_index are not defined,
they are set to a vector of -1 corresponding
to keypoints shape
"""
if not hasattr(self, "graph"):
self.graph = -np.ones(self.keypoints_shape[0])
if not hasattr(self, "swap_index"):
self.swap_index = -np.ones(self.keypoints_shape[0])
return
def __len__(self):
"""
Returns the number of samples in the generator as an integer (int64)
"""
raise NotImplementedError()
def compute_image_shape(self):
"""
Returns a tuple of integers describing
the image shape in the form:
(height, width, n_channels)
"""
raise NotImplementedError()
def compute_keypoints_shape(self):
"""
Returns a tuple of integers describing the
keypoints shape in the form:
(n_keypoints, 2), where 2 is the x,y coordinates
"""
raise NotImplementedError()
def get_images(self, indexes):
"""
Takes a list or array of indexes corresponding
to image-keypoint pairs in the dataset.
Returns a numpy array of images with the shape:
(1, height, width, n_channels)
"""
raise NotImplementedError()
def get_keypoints(self, indexes):
"""
Takes a list or array of indexes corresponding to
image-keypoint pairs in the dataset.
Returns a numpy array of keypoints with the shape:
(1, n_keypoints, 2), where 2 is the x,y coordinates
"""
raise NotImplementedError()
def set_keypoints(self, indexes, keypoints):
"""
Takes a list or array of indexes and corresponding
to keypoints.
Sets the values of the keypoints corresponding to the indexes
in the dataset.
"""
raise NotImplementedError()
def __call__(self):
return NotImplementedError()
@property
def image_shape(self):
return self.compute_image_shape()
def replace_nan(self, keypoints):
keypoints[np.isnan(keypoints)] = -99999
return keypoints
@property
def keypoints_shape(self):
return self.compute_keypoints_shape()
@property
def shape(self):
"""
Returns a tuple of tuples describing the data shapes
in the form:
((height, width, n_channels), (n_keypoints, 2))
"""
return (self.image_shape, self.keypoints_shape)
def get_data(self, indexes):
keypoints = self.get_keypoints(indexes)
keypoints = self.replace_nan(keypoints)
return (self.get_images(indexes), keypoints)
def set_data(self, indexes, keypoints):
self.set_keypoints(indexes, keypoints)
def _check_index(self, key):
if isinstance(key, slice):
start = key.start
stop = key.stop
if start is None:
start = 0
if stop is None:
stop = len(self)
if stop <= len(self):
indexes = range(start, stop)
else:
raise IndexError()
elif isinstance(key, (int, np.integer)):
if key < len(self):
indexes = [key]
else:
raise IndexError()
elif isinstance(key, np.ndarray):
if np.max(key) < len(self):
indexes = key.tolist()
else:
raise IndexError
elif isinstance(key, list):
if max(key) < len(self):
indexes = key
else:
raise IndexError()
else:
raise IndexError()
return indexes
def __getitem__(self, key):
indexes = self._check_index(key)
return self.get_data(indexes)
def __setitem__(self, key, keypoints):
indexes = self._check_index(key)
if len(keypoints) != len(indexes):
raise IndexError("data shape and index do not match")
self.set_data(indexes, keypoints[..., :2])
def get_config(self):
config = {
"generator": self.__class__.__name__,
"n_samples": len(self),
"image_shape": self.image_shape,
"keypoints_shape": self.keypoints_shape,
}
return config
Classes
class BaseGenerator (**kwargs)
-
BaseGenerator class for abstracting data loading and saving. Attributes that should be defined before use: init len compute_image_shape compute_keypoints_shape get_images get_keypoints set_keypoints (only needed for saving data)
See docstrings for further details.
Initializes the BaseGenerator class. If graph and swap_index are not defined, they are set to a vector of -1 corresponding to keypoints shape
Expand source code
class BaseGenerator(Sequence): """ BaseGenerator class for abstracting data loading and saving. Attributes that should be defined before use: __init__ __len__ compute_image_shape compute_keypoints_shape get_images get_keypoints set_keypoints (only needed for saving data) See docstrings for further details. """ def __init__(self, **kwargs): """ Initializes the BaseGenerator class. If graph and swap_index are not defined, they are set to a vector of -1 corresponding to keypoints shape """ if not hasattr(self, "graph"): self.graph = -np.ones(self.keypoints_shape[0]) if not hasattr(self, "swap_index"): self.swap_index = -np.ones(self.keypoints_shape[0]) return def __len__(self): """ Returns the number of samples in the generator as an integer (int64) """ raise NotImplementedError() def compute_image_shape(self): """ Returns a tuple of integers describing the image shape in the form: (height, width, n_channels) """ raise NotImplementedError() def compute_keypoints_shape(self): """ Returns a tuple of integers describing the keypoints shape in the form: (n_keypoints, 2), where 2 is the x,y coordinates """ raise NotImplementedError() def get_images(self, indexes): """ Takes a list or array of indexes corresponding to image-keypoint pairs in the dataset. Returns a numpy array of images with the shape: (1, height, width, n_channels) """ raise NotImplementedError() def get_keypoints(self, indexes): """ Takes a list or array of indexes corresponding to image-keypoint pairs in the dataset. Returns a numpy array of keypoints with the shape: (1, n_keypoints, 2), where 2 is the x,y coordinates """ raise NotImplementedError() def set_keypoints(self, indexes, keypoints): """ Takes a list or array of indexes and corresponding to keypoints. Sets the values of the keypoints corresponding to the indexes in the dataset. """ raise NotImplementedError() def __call__(self): return NotImplementedError() @property def image_shape(self): return self.compute_image_shape() def replace_nan(self, keypoints): keypoints[np.isnan(keypoints)] = -99999 return keypoints @property def keypoints_shape(self): return self.compute_keypoints_shape() @property def shape(self): """ Returns a tuple of tuples describing the data shapes in the form: ((height, width, n_channels), (n_keypoints, 2)) """ return (self.image_shape, self.keypoints_shape) def get_data(self, indexes): keypoints = self.get_keypoints(indexes) keypoints = self.replace_nan(keypoints) return (self.get_images(indexes), keypoints) def set_data(self, indexes, keypoints): self.set_keypoints(indexes, keypoints) def _check_index(self, key): if isinstance(key, slice): start = key.start stop = key.stop if start is None: start = 0 if stop is None: stop = len(self) if stop <= len(self): indexes = range(start, stop) else: raise IndexError() elif isinstance(key, (int, np.integer)): if key < len(self): indexes = [key] else: raise IndexError() elif isinstance(key, np.ndarray): if np.max(key) < len(self): indexes = key.tolist() else: raise IndexError elif isinstance(key, list): if max(key) < len(self): indexes = key else: raise IndexError() else: raise IndexError() return indexes def __getitem__(self, key): indexes = self._check_index(key) return self.get_data(indexes) def __setitem__(self, key, keypoints): indexes = self._check_index(key) if len(keypoints) != len(indexes): raise IndexError("data shape and index do not match") self.set_data(indexes, keypoints[..., :2]) def get_config(self): config = { "generator": self.__class__.__name__, "n_samples": len(self), "image_shape": self.image_shape, "keypoints_shape": self.keypoints_shape, } return config
Ancestors
- tensorflow.python.keras.utils.data_utils.Sequence
Subclasses
Instance variables
var image_shape
-
Expand source code
@property def image_shape(self): return self.compute_image_shape()
var keypoints_shape
-
Expand source code
@property def keypoints_shape(self): return self.compute_keypoints_shape()
var shape
-
Returns a tuple of tuples describing the data shapes in the form: ((height, width, n_channels), (n_keypoints, 2))
Expand source code
@property def shape(self): """ Returns a tuple of tuples describing the data shapes in the form: ((height, width, n_channels), (n_keypoints, 2)) """ return (self.image_shape, self.keypoints_shape)
Methods
def compute_image_shape(self)
-
Returns a tuple of integers describing the image shape in the form: (height, width, n_channels)
Expand source code
def compute_image_shape(self): """ Returns a tuple of integers describing the image shape in the form: (height, width, n_channels) """ raise NotImplementedError()
def compute_keypoints_shape(self)
-
Returns a tuple of integers describing the keypoints shape in the form: (n_keypoints, 2), where 2 is the x,y coordinates
Expand source code
def compute_keypoints_shape(self): """ Returns a tuple of integers describing the keypoints shape in the form: (n_keypoints, 2), where 2 is the x,y coordinates """ raise NotImplementedError()
def get_config(self)
-
Expand source code
def get_config(self): config = { "generator": self.__class__.__name__, "n_samples": len(self), "image_shape": self.image_shape, "keypoints_shape": self.keypoints_shape, } return config
def get_data(self, indexes)
-
Expand source code
def get_data(self, indexes): keypoints = self.get_keypoints(indexes) keypoints = self.replace_nan(keypoints) return (self.get_images(indexes), keypoints)
def get_images(self, indexes)
-
Takes a list or array of indexes corresponding to image-keypoint pairs in the dataset. Returns a numpy array of images with the shape: (1, height, width, n_channels)
Expand source code
def get_images(self, indexes): """ Takes a list or array of indexes corresponding to image-keypoint pairs in the dataset. Returns a numpy array of images with the shape: (1, height, width, n_channels) """ raise NotImplementedError()
def get_keypoints(self, indexes)
-
Takes a list or array of indexes corresponding to image-keypoint pairs in the dataset. Returns a numpy array of keypoints with the shape: (1, n_keypoints, 2), where 2 is the x,y coordinates
Expand source code
def get_keypoints(self, indexes): """ Takes a list or array of indexes corresponding to image-keypoint pairs in the dataset. Returns a numpy array of keypoints with the shape: (1, n_keypoints, 2), where 2 is the x,y coordinates """ raise NotImplementedError()
def replace_nan(self, keypoints)
-
Expand source code
def replace_nan(self, keypoints): keypoints[np.isnan(keypoints)] = -99999 return keypoints
def set_data(self, indexes, keypoints)
-
Expand source code
def set_data(self, indexes, keypoints): self.set_keypoints(indexes, keypoints)
def set_keypoints(self, indexes, keypoints)
-
Takes a list or array of indexes and corresponding to keypoints. Sets the values of the keypoints corresponding to the indexes in the dataset.
Expand source code
def set_keypoints(self, indexes, keypoints): """ Takes a list or array of indexes and corresponding to keypoints. Sets the values of the keypoints corresponding to the indexes in the dataset. """ raise NotImplementedError()