# Source code for xrmocap.data_structure.body_model.smpl_data

# yapf: disable
import logging
import numpy as np
import torch
from typing import Any, Union
from xrprimer.utils.log_utils import get_logger
from xrprimer.utils.path_utils import (
    Existence, check_path_existence, check_path_suffix,
)

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

# yapf: enable


class SMPLData(dict):
    """Dict-based container for SMPL body model parameters.

    Values assigned through ``instance[key] = value`` for the keys
    ``gender``, ``fullpose``, ``transl``, ``betas`` and ``mask`` are routed
    to the matching ``set_*`` method. That method validates the value and
    stores it as a numpy array; tensors are detached and moved to CPU. Any
    other key is stored unchanged.
    """
    BODY_POSE_LEN = 23
    # Copied to ``self.n_body_joints`` in __init__; not read elsewhere in
    # this class.
    DEFAULT_BODY_JOINTS_NUM = 45
    # Keys that from_param_dict requires in its input dict.
    BODY_POSE_KEYS = {
        'global_orient',
        'body_pose',
    }
    # Not referenced in this class; presumably used by subclasses/callers.
    FULL_POSE_KEYS = {
        'global_orient',
        'body_pose',
    }

    def __init__(self,
                 gender: Union[Literal['female', 'male', 'neutral'],
                               None] = None,
                 fullpose: Union[np.ndarray, torch.Tensor, None] = None,
                 transl: Union[np.ndarray, torch.Tensor, None] = None,
                 betas: Union[np.ndarray, torch.Tensor, None] = None,
                 mask: Union[np.ndarray, torch.Tensor, None] = None,
                 logger: Union[None, str, logging.Logger] = None) -> None:
        """Construct a SMPLData instance with pre-set values.

        Args:
            gender (Union[
                    Literal['female', 'male', 'neutral'], None], optional):
                Gender of the body model.
                Should be one among ["female", "male", "neutral"].
                Defaults to None, in which case 'neutral' is stored.
            fullpose (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for fullpose,
                in shape [n_frame, 24, 3].
                Defaults to None,
                zero-array in shape [1, 24, 3] will be created.
            transl (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for translation,
                in shape [n_frame, 3].
                Defaults to None,
                zero-array in shape [n_frame, 3] will be created.
            betas (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for body shape parameters,
                in shape [n_frame, betas_dim].
                Defaults to None,
                zero-array in shape [n_frame, 10] will be created.
            mask (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for framewise visibility mask,
                in shape [n_frame, ].
                Defaults to None,
                one-array in shape [n_frame, ] will be created.
            logger (Union[None, str, logging.Logger], optional):
                Logger for logging. If None, root logger will be selected.
                Defaults to None.
        """
        super().__init__()
        self.n_body_joints = self.__class__.DEFAULT_BODY_JOINTS_NUM
        self.logger = get_logger(logger)
        # Defaults are only filled in for keys that are not present yet
        # (dict.__init__() without arguments does not clear existing items).
        if gender is None and 'gender' not in self:
            gender = 'neutral'
        if gender is not None:
            self.set_gender(gender)
        if fullpose is None and 'fullpose' not in self:
            fullpose_dim = self.__class__.get_fullpose_dim()
            fullpose = np.zeros(shape=(1, fullpose_dim, 3))
        if fullpose is not None:
            self.set_fullpose(fullpose)
        # transl/betas/mask defaults are sized after fullpose's batch size.
        if transl is None and 'transl' not in self:
            transl = np.zeros(shape=(self.get_batch_size(), 3))
        if transl is not None:
            self.set_transl(transl)
        if betas is None and 'betas' not in self:
            betas = np.zeros(shape=(self.get_batch_size(), 10))
        if betas is not None:
            self.set_betas(betas)
        if mask is None and 'mask' not in self:
            mask = np.ones(shape=(self.get_batch_size(), ))
        if mask is not None:
            self.set_mask(mask)

    @classmethod
    def fromfile(cls, npz_path: str) -> 'SMPLData':
        """Construct a body model data structure from an npz file.

        Args:
            npz_path (str):
                Path to a dumped npz file.

        Returns:
            SMPLData:
                A SMPLData instance load from file.
        """
        ret_instance = cls()
        ret_instance.load(npz_path)
        return ret_instance

    @classmethod
    def from_dict(cls,
                  smpl_data_dict: Union['SMPLData', dict]) -> 'SMPLData':
        """Construct a body model data structure from a SMPLData, or a
        degraded smpl_data in dict type.

        Args:
            smpl_data_dict (dict):
                A degraded smpl_data in dict type. Must contain gender,
                fullpose, transl and betas. An optional 'mask' entry is
                kept; without it an all-ones mask is created.

        Returns:
            SMPLData:
                A SMPLData instance load from dict.
        """
        min_keys = {'gender', 'fullpose', 'transl', 'betas'}
        assert min_keys <= smpl_data_dict.keys(), \
            f'Missing keys: {min_keys - set(smpl_data_dict.keys())}'
        ret_instance = cls(
            gender=smpl_data_dict['gender'],
            fullpose=smpl_data_dict['fullpose'],
            transl=smpl_data_dict['transl'],
            betas=smpl_data_dict['betas'],
            # Preserve an existing visibility mask instead of silently
            # resetting it; None falls back to the all-ones default.
            mask=smpl_data_dict.get('mask', None),
        )
        return ret_instance

    @classmethod
    def get_fullpose_dim(cls) -> int:
        """Get dimension of full pose.

        Returns:
            int:
                Dim value. Full pose shall be
                in shape (frame_n, dim, 3)
        """
        global_orient_dim = 1
        ret_sum = global_orient_dim + cls.BODY_POSE_LEN
        return ret_sum

    def set_gender(
            self,
            gender: Literal['female', 'male', 'neutral'] = 'neutral') -> None:
        """Set gender.

        Args:
            gender (Literal["female", "male", "neutral"], optional):
                Gender of the body model.
                Should be one among ["female", "male", "neutral"].
                Defaults to 'neutral'.

        Raises:
            ValueError: Value of gender is not correct.
        """
        if gender in ['female', 'male', 'neutral']:
            super().__setitem__('gender', gender)
        else:
            msg = 'Value of gender is not correct.\n' + \
                f'gender: {gender}.'
            self.logger.error(msg)
            raise ValueError(msg)

    def set_fullpose(self, fullpose: Union[np.ndarray,
                                           torch.Tensor]) -> None:
        """Set full pose data.

        Args:
            fullpose (Union[np.ndarray, torch.Tensor]):
                Full pose in ndarray or tensor,
                in shape [batch_size, fullpose_dim, 3].
                global_orient at [:, 0, :].

        Raises:
            TypeError: Type of fullpose is not correct.
        """
        fullpose_dim = self.__class__.get_fullpose_dim()
        if isinstance(fullpose, np.ndarray):
            fullpose_np = fullpose.reshape(-1, fullpose_dim, 3)
        elif isinstance(fullpose, torch.Tensor):
            fullpose_np = fullpose.detach().cpu().numpy().reshape(
                -1, fullpose_dim, 3)
        else:
            msg = 'Type of fullpose is not correct.\n' + \
                f'Type: {type(fullpose)}.'
            self.logger.error(msg)
            raise TypeError(msg)
        super().__setitem__('fullpose', fullpose_np)

    def set_transl(self, transl: Union[np.ndarray, torch.Tensor]) -> None:
        """Set translation data.

        Args:
            transl (Union[np.ndarray, torch.Tensor]):
                Translation in ndarray or tensor,
                in shape [batch_size, 3].

        Raises:
            TypeError: Type of transl is not correct.
        """
        if isinstance(transl, np.ndarray):
            transl_np = transl.reshape(-1, 3)
        elif isinstance(transl, torch.Tensor):
            transl_np = transl.detach().cpu().numpy().reshape(-1, 3)
        else:
            msg = 'Type of transl is not correct.\n' + \
                f'Type: {type(transl)}.'
            self.logger.error(msg)
            raise TypeError(msg)
        super().__setitem__('transl', transl_np)

    def set_betas(self, betas: Union[np.ndarray, torch.Tensor]) -> None:
        """Set betas data.

        Args:
            betas (Union[np.ndarray, torch.Tensor]):
                Body shape parameters in ndarray or tensor,
                in shape [batch_size, n]. n stands for any positive int,
                typically it's 10. A 1-D input is treated as one frame.

        Raises:
            TypeError: Type of betas is not correct.
        """
        if isinstance(betas, torch.Tensor):
            betas = betas.detach().cpu().numpy()
        elif not isinstance(betas, np.ndarray):
            msg = 'Type of betas is not correct.\n' + \
                f'Type: {type(betas)}.'
            self.logger.error(msg)
            raise TypeError(msg)
        if len(betas.shape) == 1:
            betas = betas[np.newaxis, ...]
        betas_dim = betas.shape[-1]
        betas_np = betas.reshape(-1, betas_dim)
        super().__setitem__('betas', betas_np)

    def set_mask(self, mask: Union[np.ndarray, torch.Tensor]) -> None:
        """Set framewise mask data.

        Args:
            mask (Union[np.ndarray, torch.Tensor]):
                Visibility mask in ndarray or tensor,
                in shape [batch_size, ]. Stored as uint8.

        Raises:
            TypeError: Type of mask is not correct.
        """
        if isinstance(mask, np.ndarray):
            mask_np = mask.reshape(-1).astype(np.uint8)
        elif isinstance(mask, torch.Tensor):
            mask_np = mask.detach().cpu().numpy().reshape(-1).astype(
                np.uint8)
        else:
            msg = 'Type of mask is not correct.\n' + \
                f'Type: {type(mask)}.'
            self.logger.error(msg)
            raise TypeError(msg)
        super().__setitem__('mask', mask_np)

    def __setitem__(self, __k: Any, __v: Any) -> None:
        """Set item according to its key.

        Known keys are routed through their validating setter; any other
        key is stored as-is.

        Args:
            __k (Any):
                Key in dict.
            __v (Any):
                Value in dict.
        """
        if __k == 'gender':
            self.set_gender(__v)
        elif __k == 'transl':
            self.set_transl(__v)
        elif __k == 'fullpose':
            self.set_fullpose(__v)
        elif __k == 'betas':
            self.set_betas(__v)
        elif __k == 'mask':
            self.set_mask(__v)
        else:
            super().__setitem__(__k, __v)

    def get_batch_size(self) -> int:
        """Get batch size.

        Returns:
            int:
                batch size of fullpose.
        """
        return self['fullpose'].shape[0]

    def get_fullpose(self) -> np.ndarray:
        """Get fullpose.

        Returns:
            ndarray:
                fullpose in shape [batch_size, fullpose_dim, 3].
        """
        return self['fullpose']

    def get_global_orient(self) -> np.ndarray:
        """Get global_orient.

        Returns:
            ndarray:
                global_orient in shape [batch_size, 3].
        """
        fullpose = self.get_fullpose()
        global_orient = fullpose[:, 0].reshape(-1, 3)
        return global_orient

    def get_body_pose(self) -> np.ndarray:
        """Get body_pose.

        Returns:
            ndarray:
                body_pose in shape [batch_size, BODY_POSE_LEN, 3].
        """
        fullpose = self.get_fullpose()
        batch_size = self.get_batch_size()
        # Index 0 of fullpose is global_orient; body pose follows it.
        start_idx = 1
        body_pose = fullpose[:, start_idx:start_idx +
                             self.__class__.BODY_POSE_LEN].reshape(
                                 batch_size, -1, 3)
        return body_pose

    def get_transl(self) -> np.ndarray:
        """Get translation.

        Returns:
            ndarray:
                translation in shape [batch_size, 3].
        """
        return self['transl'].reshape(-1, 3)

    def get_betas(self, repeat_betas: bool = True) -> np.ndarray:
        """Get betas.

        Args:
            repeat_betas (bool, optional):
                Whether to repeat betas when its first dim
                is 1 and doesn't match batch_size.
                Defaults to True.

        Returns:
            ndarray:
                betas in shape [batch_size, betas_dims] when repeated,
                otherwise as stored ([n, betas_dims]).
        """
        batch_size = self.get_batch_size()
        betas = self['betas']
        if repeat_betas and\
                betas.shape[0] == 1 and\
                betas.shape[0] != batch_size:
            betas = betas.repeat(repeats=batch_size, axis=0)
        return betas

    def get_mask(self) -> np.ndarray:
        """Get mask.

        Returns:
            ndarray:
                mask in shape [batch_size, ].
        """
        return self['mask'].reshape(-1)

    def get_gender(self) -> str:
        """Get gender.

        Returns:
            str:
                gender in ['neutral', 'female', 'male'].
        """
        return self['gender']

    def to_param_dict(self, repeat_betas: bool = True) -> dict:
        """Split fullpose into global_orient and body_pose, return all the
        necessary parameters in one dict.

        Args:
            repeat_betas (bool, optional):
                Whether to repeat betas when its first dim doesn't match
                batch_size.
                Defaults to True.

        Returns:
            dict:
                A dict of SMPL data, whose keys are
                betas, body_pose, global_orient and transl.
                body_pose is flattened to [batch_size, BODY_POSE_LEN * 3].
        """
        batch_size = self.get_batch_size()
        body_pose = self.get_body_pose().reshape(batch_size, -1)
        global_orient = self.get_global_orient().reshape(batch_size, 3)
        transl = self.get_transl()
        betas = self.get_betas(repeat_betas=repeat_betas)
        dict_to_return = {
            'betas': betas,
            'body_pose': body_pose,
            'global_orient': global_orient,
            'transl': transl,
        }
        return dict_to_return

    def to_tensor_dict(self,
                       repeat_betas: bool = True,
                       device: Union[torch.device, str] = 'cpu') -> dict:
        """It is almost same as self.to_param_dict, but all the values are
        float32 tensors on a specified device. Split fullpose into
        global_orient and body_pose, return all the necessary parameters in
        one dict.

        Args:
            repeat_betas (bool, optional):
                Whether to repeat betas when its first dim doesn't match
                batch_size.
                Defaults to True.
            device (Union[torch.device, str], optional):
                A specified device.
                Defaults to 'cpu'.

        Returns:
            dict:
                A dict of SMPL data, whose keys are
                betas, body_pose, global_orient and transl.
        """
        np_dict = self.to_param_dict(repeat_betas=repeat_betas)
        dict_to_return = {}
        for key, value in np_dict.items():
            dict_to_return[key] = torch.tensor(
                value, device=device, dtype=torch.float32)
        return dict_to_return

    def dump(self, npz_path: str, overwrite: bool = True):
        """Dump keys and items to an npz file.

        Args:
            npz_path (str):
                Path to a dumped npz file.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                npz_path does not end with '.npz'.
            FileExistsError:
                When overwrite is False and file exists.
        """
        if not check_path_suffix(npz_path, ['.npz']):
            msg = 'Not an npz file.\n' + \
                f'npz_path: {npz_path}'
            self.logger.error(msg)
            raise ValueError(msg)
        if not overwrite:
            if check_path_existence(npz_path, 'file') == Existence.FileExist:
                msg = 'File exists while overwrite option not checked.\n' + \
                    f'npz_path: {npz_path}'
                self.logger.error(msg)
                raise FileExistsError(msg)
        # Non-array values (e.g. the gender string) become 0-d arrays here;
        # load() converts them back with .item().
        np.savez_compressed(npz_path, **self)

    def from_param_dict(self, smpl_dict: dict) -> None:
        """Load SMPL parameters from smpl_dict, which is the output of a
        body model in most cases.

        Args:
            smpl_dict (dict):
                A dict of ndarray|Tensor parameters.
                global_orient and body_pose are necessary,
                transl and betas are optional.
                Other keys are ignored. ndarray and Tensor values
                may be mixed.

        Raises:
            KeyError: missing necessary keys.
            TypeError: global_orient or body_pose is neither an ndarray
                nor a Tensor.
        """
        if not self.__class__.BODY_POSE_KEYS.issubset(smpl_dict):
            msg = 'Keys are not enough.\n' + \
                f'smpl_dict\'s keys: {smpl_dict.keys()}'
            self.logger.error(msg)
            raise KeyError(msg)

        def to_ndarray(key: str) -> np.ndarray:
            # Normalise to numpy up front so mixed ndarray/Tensor inputs can
            # be concatenated; set_fullpose would convert to numpy anyway.
            value = smpl_dict[key]
            if isinstance(value, torch.Tensor):
                return value.detach().cpu().numpy()
            if isinstance(value, np.ndarray):
                return value
            msg = f'Type of {key} is not correct.\n' + \
                f'Type: {type(value)}.'
            self.logger.error(msg)
            raise TypeError(msg)

        global_orient = to_ndarray('global_orient')
        body_pose = to_ndarray('body_pose')
        # A 1-D global_orient (shape [3]) is treated as a single frame.
        batch_size = global_orient.shape[0] \
            if global_orient.ndim > 1 else 1
        global_orient = global_orient.reshape(batch_size, 3)
        body_pose = body_pose.reshape(batch_size, -1)
        fullpose = np.concatenate([global_orient, body_pose], axis=1)
        self.set_fullpose(fullpose)
        if 'transl' in smpl_dict:
            self.set_transl(smpl_dict['transl'])
        if 'betas' in smpl_dict:
            self.set_betas(smpl_dict['betas'])
        # reset mask to make sure correspondence
        self.set_mask(np.ones(shape=(self.get_batch_size(), )))

    def load(self, npz_path: str):
        """Load data from npz_path and update them to self.

        Args:
            npz_path (str):
                Path to a dumped npz file.
        """
        # NOTE(review): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code. Only load trusted files.
        with np.load(npz_path, allow_pickle=True) as npz_file:
            tmp_data_dict = dict(npz_file)
        for key, value in tmp_data_dict.items():
            if isinstance(value, np.ndarray) and len(value.shape) == 0:
                # value is not an ndarray before dump
                value = value.item()
            self.__setitem__(key, value)