# yapf: disable
import logging
import numpy as np
import torch
from typing import Any, Union
from xrprimer.utils.log_utils import get_logger
from xrprimer.utils.path_utils import (
Existence, check_path_existence, check_path_suffix,
)
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
# yapf: enable
class Keypoints(dict):
    """A class for multi-frame, multi-person keypoints data, based on python
    dict.

    keypoints, mask and convention are the three necessary keys, and we advise
    you to just call Keypoints(). If you'd like to set them manually, it is
    recommended to follow this order: convention -> keypoints -> mask.
    Assigning one of these three keys with ``instance[key] = value`` is routed
    to the matching ``set_*`` method; any other key is stored as-is.
    """

    def __init__(self,
                 src_dict: dict = None,
                 dtype: Literal['torch', 'numpy', 'auto'] = 'auto',
                 kps: Union[np.ndarray, torch.Tensor, None] = None,
                 mask: Union[np.ndarray, torch.Tensor, None] = None,
                 convention: Union[str, None] = None,
                 logger: Union[None, str, logging.Logger] = None) -> None:
        """Construct a Keypoints instance with pre-set values. If any of kps,
        mask, convention is provided, it will override the item in src_dict.

        Items of src_dict are copied by the dict constructor, so they are
        stored without passing through set_keypoints(), set_mask() or
        set_convention().

        Args:
            src_dict (dict, optional):
                A dict with items in Keypoints fashion.
                Defaults to None.
            dtype (Literal['torch', 'numpy', 'auto'], optional):
                The data type of this Keypoints instance, values will
                be converted to the certain dtype when setting. If
                dtype==auto, it is resolved from kps, or from
                src_dict['keypoints'], or the first time an array
                setter is called, and never changes afterwards.
                Defaults to 'auto'.
            kps (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for keypoints,
                kps2d in shape [n_frame, n_person, n_kps, 3],
                kps3d in shape [n_frame, n_person, n_kps, 4].
                Shape [n_kps, 3 or 4] is also accepted, unsqueezed
                automatically. Defaults to None.
            mask (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for keypoint mask,
                in shape [n_frame, n_person, n_kps],
                in dtype uint8.
                Shape [n_kps, ] is also accepted, unsqueezed
                automatically. When no mask is available but keypoints
                are, an all-ones mask is created.
                Defaults to None.
            convention (str, optional):
                Convention name of the keypoints,
                can be found in KEYPOINTS_FACTORY.
                Defaults to None.
            logger (Union[None, str, logging.Logger], optional):
                Logger for logging. If None, root logger will be selected.
                Defaults to None.
        """
        if src_dict is not None:
            super().__init__(src_dict)
        else:
            super().__init__()
        self.logger = get_logger(logger)
        if dtype == 'auto':
            # Prefer the explicit kps argument, then fall back to src_dict.
            if kps is not None:
                dtype = __get_array_type_str__(kps, logger)
            elif src_dict is not None and 'keypoints' in src_dict:
                dtype = __get_array_type_str__(src_dict['keypoints'], logger)
        self.dtype = dtype
        if convention is not None:
            self.set_convention(convention)
        if kps is not None:
            self.set_keypoints(kps)
        if mask is None and 'mask' not in self and 'keypoints' in self:
            # No mask anywhere: mark every keypoint as valid.
            default_n_kps = self.get_keypoints_number()
            mask = np.ones(shape=(default_n_kps, ))
        if mask is not None:
            self.set_mask(mask)

    @classmethod
    def fromfile(cls, npz_path: str) -> 'Keypoints':
        """Construct a Keypoints instance from an npz file.

        Args:
            npz_path (str):
                Path to a dumped npz file.

        Returns:
            Keypoints:
                A Keypoints instance loaded from file.
        """
        ret_instance = cls()
        ret_instance.load(npz_path)
        return ret_instance

    def set_keypoints(self, kps: Union[np.ndarray, torch.Tensor]) -> None:
        """Set keypoints array.

        Args:
            kps (Union[np.ndarray, torch.Tensor]):
                A tensor or ndarray for keypoints,
                kps2d in shape [n_frame, n_person, n_kps, 3],
                kps3d in shape [n_frame, n_person, n_kps, 4].
                Shape [n_kps, 3 or 4] is also accepted, unsqueezed
                automatically.

        Raises:
            TypeError: Type of keypoints is wrong.
            ValueError: kps.shape[-1] is wrong.
            ValueError: Shape of kps is wrong.
        """
        if self.dtype == 'auto':
            self.dtype = __get_array_type_str__(kps, self.logger)
        keypoints = __get_array_in_type__(
            array=kps, type=self.dtype, logger=self.logger)
        # shape: frame_n, person_n, kp_n, dim+score
        if keypoints.shape[-1] not in (3, 4):
            err_msg = ('shape[-1] of kps2d should be 3,' +
                       ' shape[-1] of kps3d should be 4.\n' +
                       f'kps.shape[-1]: {kps.shape[-1]}.')
            self.logger.error(err_msg)
            raise ValueError(err_msg)
        if len(keypoints.shape) == 2:
            # Single frame, single person: add the two leading axes.
            keypoints = keypoints.reshape(1, 1, keypoints.shape[0],
                                          keypoints.shape[1])
        if len(keypoints.shape) != 4:
            err_msg = ('Shape of keypoints should be' +
                       ' [n_frame, n_person, n_kps, dim+1].\n' +
                       f'kps.shape: {kps.shape}.')
            self.logger.error(err_msg)
            raise ValueError(err_msg)
        # Bypass the overridden __setitem__ to avoid infinite recursion.
        super().__setitem__('keypoints', keypoints)

    def set_convention(self, convention: str) -> None:
        """Set convention name of the keypoints.

        Args:
            convention (str):
                Convention name of the keypoints,
                can be found in KEYPOINTS_FACTORY.

        Raises:
            TypeError: Type of convention is not str.
        """
        if not isinstance(convention, str):
            err_msg = ('Type of convention is not str.\n' +
                       f'type(convention): {type(convention)}.')
            self.logger.error(err_msg)
            raise TypeError(err_msg)
        super().__setitem__('convention', convention)

    def set_mask(self, mask: Union[np.ndarray, torch.Tensor]) -> None:
        """Set mask of the keypoints. It should be called after the
        corresponding keypoints has been set.

        Args:
            mask (Union[np.ndarray, torch.Tensor]):
                A tensor or ndarray for keypoint mask,
                in shape [n_frame, n_person, n_kps],
                in dtype uint8.
                Shape [n_kps, ] is also accepted, it is tiled over
                every frame and person of the stored keypoints.

        Raises:
            TypeError: Type of mask is wrong.
            ValueError: Shape of mask is wrong.
            KeyError: Keypoints have not been set yet.
        """
        if self.dtype == 'auto':
            self.dtype = __get_array_type_str__(mask, self.logger)
        mask = __get_array_in_type__(
            array=mask, type=self.dtype, logger=self.logger)
        if self.dtype == 'torch':
            mask = mask.to(dtype=torch.uint8)
        else:
            mask = mask.astype(np.uint8)
        keypoints_shape = self.get_keypoints().shape
        if len(mask.shape) == 1:
            n_frame, n_person = keypoints_shape[0], keypoints_shape[1]
            mask = mask.reshape(1, 1, mask.shape[0])
            # torch.Tensor.repeat() takes per-dimension repeat counts and
            # has no ``axis`` keyword, so each backend needs its own call.
            if isinstance(mask, torch.Tensor):
                mask = mask.repeat(n_frame, n_person, 1)
            else:
                mask = np.tile(mask, (n_frame, n_person, 1))
        if len(mask.shape) != 3 or \
                tuple(mask.shape) != tuple(keypoints_shape[:3]):
            err_msg = ('Shape of mask should be' +
                       ' [n_frame, n_person, n_kps].\n' +
                       f'mask.shape: {mask.shape}.' +
                       f'keypoints.shape: {keypoints_shape}.')
            self.logger.error(err_msg)
            raise ValueError(err_msg)
        super().__setitem__('mask', mask)

    def __setitem__(self, __k: Any, __v: Any) -> None:
        """Set item according to its key.

        keypoints, convention and mask are forwarded to their setters so
        that they get validated; other keys are stored untouched.

        Args:
            __k (Any): Key in dict.
            __v (Any): Value in dict.
        """
        if __k == 'keypoints':
            self.set_keypoints(__v)
        elif __k == 'convention':
            self.set_convention(__v)
        elif __k == 'mask':
            self.set_mask(__v)
        else:
            super().__setitem__(__k, __v)

    def get_keypoints(self) -> Union[np.ndarray, torch.Tensor]:
        """Get keypoints array.

        Returns:
            Union[np.ndarray, torch.Tensor]: keypoints
        """
        return self['keypoints']

    def get_mask(self) -> Union[np.ndarray, torch.Tensor]:
        """Get keypoints mask.

        Returns:
            Union[np.ndarray, torch.Tensor]: mask
        """
        return self['mask']

    def get_convention(self) -> str:
        """Get keypoints convention name.

        Returns:
            str: convention
        """
        return self['convention']

    def get_frame_number(self) -> int:
        """Get frame number of keypoints.

        Returns:
            int: frame number
        """
        return self.get_keypoints().shape[0]

    def get_person_number(self) -> int:
        """Get person number of keypoints.

        Returns:
            int: person number
        """
        return self.get_keypoints().shape[1]

    def get_keypoints_number(self) -> int:
        """Get number of keypoints.

        Returns:
            int: keypoints number
        """
        return self.get_keypoints().shape[2]

    def to_tensor(self,
                  device: Union[torch.device, str] = 'cpu') -> 'Keypoints':
        """Return all the necessary values for keypoints expression in another
        Keypoints instance, convert ndarray into Tensor.

        Args:
            device (Union[torch.device, str], optional):
                A specified device.
                Defaults to 'cpu'.

        Returns:
            Keypoints: An instance of Keypoints data, whose keys are
                keypoints, mask, convention.
        """
        kps_to_return = self.__class__(
            dtype='torch',
            kps=self.get_keypoints(),
            mask=self.get_mask(),
            convention=self.get_convention(),
            logger=self.logger)
        # Values are tensors now; move them to the requested device.
        kps_to_return.set_keypoints(kps_to_return.get_keypoints().to(device))
        kps_to_return.set_mask(kps_to_return.get_mask().to(device))
        return kps_to_return

    def to_numpy(self, ) -> 'Keypoints':
        """Return all the necessary values for keypoints expression in another
        Keypoints instance, convert Tensor into numpy.

        Returns:
            Keypoints: An instance of Keypoints data, whose keys are
                keypoints, mask, convention.
        """
        kps_to_return = self.__class__(
            dtype='numpy',
            kps=self.get_keypoints(),
            mask=self.get_mask(),
            convention=self.get_convention(),
            logger=self.logger)
        return kps_to_return

    def dump(self, npz_path: str, overwrite: bool = True):
        """Dump keys and items to an npz file.

        Every item of this instance is written. Tensor values are converted
        to ndarray first, other values are handed to numpy unchanged.

        Args:
            npz_path (str):
                Path to a dumped npz file.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                npz_path does not end with '.npz'.
            FileExistsError:
                When overwrite is False and file exists.
        """
        if not check_path_suffix(npz_path, ['.npz']):
            err_msg = 'Not an npz file.\n' + f'npz_path: {npz_path}'
            self.logger.error(err_msg)
            raise ValueError(err_msg)
        if not overwrite and \
                check_path_existence(npz_path, 'file') == Existence.FileExist:
            err_msg = ('File exists while overwrite option not checked.\n' +
                       f'npz_path: {npz_path}')
            self.logger.error(err_msg)
            raise FileExistsError(err_msg)
        # Convert item by item instead of going through to_numpy(), which
        # keeps only three keys and needs all of them to be present.
        dict_to_save = {
            key: value.detach().cpu().numpy()
            if isinstance(value, torch.Tensor) else value
            for key, value in self.items()
        }
        np.savez_compressed(npz_path, **dict_to_save)

    def load(self, npz_path: str):
        """Load data from npz_path and update them to self.

        convention, keypoints and mask are applied first, in that order,
        because set_mask() reads the stored keypoints; remaining keys follow
        in file order.

        Args:
            npz_path (str):
                Path to a dumped npz file.
        """
        # NOTE(security): allow_pickle=True lets numpy unpickle object
        # arrays, which can execute arbitrary code. Only load trusted files.
        with np.load(npz_path, allow_pickle=True) as npz_file:
            tmp_data_dict = dict(npz_file)
        primary_keys = ('convention', 'keypoints', 'mask')
        ordered_keys = [key for key in primary_keys if key in tmp_data_dict]
        ordered_keys.extend(
            key for key in tmp_data_dict if key not in primary_keys)
        for key in ordered_keys:
            value = tmp_data_dict[key]
            if isinstance(value, np.ndarray) and len(value.shape) == 0:
                # value is not an ndarray before dump
                value = value.item()
            self.__setitem__(key, value)

    def clone(self) -> 'Keypoints':
        """Clone a Keypoints instance as self.

        Returns:
            Keypoints:
                A deep copied instance of Keypoints,
                with the same dtype and value as self.
        """
        ret_kps = self.__class__(
            dtype=self.dtype,
            kps=__copy_array_tensor__(self.get_keypoints()),
            mask=__copy_array_tensor__(self.get_mask()),
            convention=self.get_convention(),
            logger=self.logger)
        return ret_kps
def __get_array_type_str__(array, logger) -> Literal['torch', 'numpy']:
    """Tell which array backend ``array`` belongs to.

    Args:
        array: Object to inspect, expected to be a torch.Tensor
            or an np.ndarray.
        logger: Logger, logger name or None, resolved by get_logger()
            only when an error has to be reported.

    Returns:
        Literal['torch', 'numpy']:
            'torch' for a torch.Tensor, 'numpy' for an np.ndarray.

    Raises:
        TypeError: array is neither a torch.Tensor nor an np.ndarray.
    """
    # torch.Tensor is checked first, then np.ndarray.
    known_backends = (('torch', torch.Tensor), ('numpy', np.ndarray))
    for backend_name, backend_class in known_backends:
        if isinstance(array, backend_class):
            return backend_name
    logger = get_logger(logger)
    logger.error('Type of array is not correct.\n' +
                 f'Type: {type(array)}.')
    raise TypeError
def __get_array_in_type__(array: Union[torch.Tensor, np.ndarray],
                          type: Literal['torch', 'numpy'],
                          logger: Union[None, str, logging.Logger]):
    """Convert ``array`` to the requested array backend.

    Args:
        array (Union[torch.Tensor, np.ndarray]):
            The data to convert.
        type (Literal['torch', 'numpy']):
            Target backend. 'numpy' yields an np.ndarray, any other
            value is treated as 'torch' and yields a torch.Tensor.
        logger (Union[None, str, logging.Logger]):
            Logger, logger name or None, resolved by get_logger()
            only when an error has to be reported.

    Returns:
        Union[torch.Tensor, np.ndarray]:
            ``array`` itself when it already has the target type,
            otherwise the converted data. A tensor is detached and
            moved to cpu before conversion to ndarray.

    Raises:
        TypeError: array is neither a torch.Tensor nor an np.ndarray.
    """
    if type == 'numpy':
        if isinstance(array, torch.Tensor):
            return array.detach().cpu().numpy()
        if isinstance(array, np.ndarray):
            return array
    else:  # type == 'torch'
        if isinstance(array, np.ndarray):
            return torch.from_numpy(array)
        if isinstance(array, torch.Tensor):
            return array
    # The parameter named ``type`` shadows the builtin inside this function,
    # so the class of array has to be read from its __class__ attribute.
    err_msg = ('Type of array is not correct.\n' +
               f'Type: {array.__class__}.')
    get_logger(logger).error(err_msg)
    raise TypeError(err_msg)
def __copy_array_tensor__(
    data: Union[np.ndarray,
                torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    """Duplicate an ndarray or a tensor.

    Args:
        data (Union[np.ndarray, torch.Tensor]):
            The data to duplicate.

    Returns:
        Union[np.ndarray, torch.Tensor]:
            ``data.copy()`` for an np.ndarray, ``data.clone()`` for
            anything else.
    """
    # Pick the bound duplicating method of the backend, then call it.
    duplicate = data.copy if isinstance(data, np.ndarray) else data.clone
    return duplicate()