Source code for xrmocap.data_structure.body_model.smplxd_data

# yapf: disable

import logging
import numpy as np
import torch
from typing import Any, Union

from .smplx_data import SMPLXData

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

# yapf: enable


[docs]class SMPLXDData(SMPLXData): NUM_VERTS = 10475 def __init__(self, gender: Union[Literal['female', 'male', 'neutral'], None] = None, fullpose: Union[np.ndarray, torch.Tensor, None] = None, transl: Union[np.ndarray, torch.Tensor, None] = None, betas: Union[np.ndarray, torch.Tensor, None] = None, expression: Union[np.ndarray, torch.Tensor, None] = None, displacement: Union[np.ndarray, torch.Tensor, None] = None, mask: Union[np.ndarray, torch.Tensor, None] = None, logger: Union[None, str, logging.Logger] = None) -> None: """Construct a SMPLXData instance with pre-set values. Args: gender (Union[ Literal['female', 'male', 'neutral'], None], optional): Gender of the body model. Should be one among ["female", "male", "neutral"]. Defaults to None. fullpose (Union[np.ndarray, torch.Tensor, None], optional): A tensor or ndarray for fullpose, in shape [frame_num, 24, 3]. Defaults to None, zero-tensor will be created. transl (Union[np.ndarray, torch.Tensor, None], optional): A tensor or ndarray for translation, in shape [frame_num, 3]. Defaults to None, zero-tensor will be created. betas (Union[np.ndarray, torch.Tensor, None], optional): A tensor or ndarray for betas, in shape [frame_num, betas_dim]. Defaults to None, zero-tensor in shape [frame_num, 10] will be created. expression (Union[np.ndarray, torch.Tensor, None], optional): A tensor or ndarray for expression, in shape [frame_num, expression_dim]. Defaults to None, zero-tensor in shape [frame_num, 10] will be created. displacement (Union[np.ndarray, torch.Tensor, None], optional): A tensor or ndarray for displacement, in shape [frame_num, NUM_VERTS, 3]. Defaults to None, zero-tensor in shape [frame_num, NUM_VERTS] will be created. mask (Union[np.ndarray, torch.Tensor, None], optional): A tensor or ndarray for framewise visibility mask, in shape [n_frame, ]. Defaults to None, one-tensor in shape [n_frame, ] will be created. logger (Union[None, str, logging.Logger], optional): Logger for logging. If None, root logger will be selected. Defaults to None. """ SMPLXData.__init__( self, gender=gender, transl=transl, fullpose=fullpose, betas=betas, expression=expression, mask=mask, logger=logger) if displacement is None and 'displacement' not in self: displacement = np.zeros( shape=(self.get_batch_size(), self.__class__.NUM_VERTS, 3)) if displacement is not None: self.set_displacement(displacement)
[docs] @classmethod def from_dict(cls, smpl_data_dict: Union['SMPLXDData', dict]) -> 'SMPLXDData': """Construct a body model data structure from a SMPLXDData, or a degraded smplxd_data in dict type. Args: smplxd_data_dict (dict): A degraded smplxd_data in dict type. Returns: SMPLXData: A SMPLXDData instance load from dict. """ smplxd_data_dict = smpl_data_dict min_keys = { 'gender', 'fullpose', 'transl', 'betas', 'expression', 'displacement' } assert min_keys <= smplxd_data_dict.keys() ret_instance = cls( gender=smplxd_data_dict['gender'], fullpose=smplxd_data_dict['fullpose'], transl=smplxd_data_dict['transl'], betas=smplxd_data_dict['betas'], expression=smplxd_data_dict['expression'], displacement=smplxd_data_dict['displacement']) return ret_instance
[docs] def set_displacement( self, displacement: Union[np.ndarray, torch.Tensor]) -> None: """Set displacement data. Args: displacement (Union[np.ndarray, torch.Tensor]): Displacement parameters in ndarray or tensor, in shape [batch_size, NUM_VERTS, 3]. Raises: TypeError: Type of displacement is not correct. """ if isinstance(displacement, torch.Tensor): displacement = displacement.detach().cpu().numpy() elif not isinstance(displacement, np.ndarray): self.logger.error('Type of displacement is not correct.\n' + f'Type: {type(displacement)}.') raise TypeError if len(displacement.shape) < 3: self.logger.error('Shape of displacement is not correct.\n' + f'Shape: {type(displacement.shape)}.') raise ValueError dict.__setitem__(self, 'displacement', displacement)
[docs] def get_displacement(self) -> np.ndarray: """Get displacement. Returns: ndarray: Displacement in shape [batch_size, NUM_VERTS, 3]. """ displacement = self.__getitem__('displacement') return displacement
def __setitem__(self, __k: Any, __v: Any) -> None: """Set item according to its key. Args: __k (Any): Key in dict. __v (Any): Value in dict. """ if __k == 'displacement': self.set_displacement(__v) else: SMPLXData.__setitem__(self, __k, __v)
[docs] def from_param_dict(self, smplxd_dict: dict) -> None: """Load SMPLX+D parameters from smplxd_dict, which is the output of a body model in most cases. Args: smplx_dict (dict): A dict of ndarray|Tensor parameters. global_orient and body_pose are necessary, jaw_pose, leye_pose, reye_pose, left_hand_pose, right_hand_pose, transl and betas are optional. Other keys are ignored. Raises: KeyError: missing necessary keys. """ necessary_keys = {'global_orient', 'body_pose', 'displacement'} if not necessary_keys.issubset(smplxd_dict): self.logger.error('Keys are not enough.\n' + f'smplxd_dict\'s keys: {smplxd_dict.keys()}') raise KeyError SMPLXData.from_param_dict(self, smplxd_dict) self.set_displacement(smplxd_dict['displacement'])
[docs] def to_param_dict(self, repeat_betas: bool = True, repeat_expression: bool = True) -> dict: """Split fullpose into global_orient, body_pose, jaw_pose, leye_pose, reye_pose, left_hand_pose, right_hand_pose, return all the necessary parameters in one dict. Args: repeat_betas (bool, optional): Whether to repeat betas when its first dim doesn't match batch_size. Defaults to True. repeat_expression (bool, optional): Whether to repeat expression when its first dim doesn't match batch_size. Defaults to True. Returns: dict: A dict of SMPLX data, whose keys are betas, global_orient, transl, global_orient, body_pose, jaw_pose, leye_pose, reye_pose, left_hand_pose, right_hand_pose, expression. """ dict_to_return = SMPLXData.to_param_dict( self, repeat_betas=repeat_betas, repeat_expression=repeat_expression) dict_to_return['displacement'] = self.get_displacement() return dict_to_return