# yapf: disable
import logging
import numpy as np
import torch
from typing import Any, Union
from .smpl_data import SMPLData
# `Literal` is only available from `typing` on newer Python versions;
# fall back to the `typing_extensions` backport when that import fails.
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal
# yapf: enable
class SMPLXData(SMPLData):
    """Dict-based container of SMPL-X parameters, extending SMPLData.

    In addition to what SMPLData stores, an ``expression`` array is kept
    under the key ``'expression'``. Fullpose is laid out joint-wise in the
    order global_orient(1), body_pose(21), jaw_pose(1), leye_pose(1),
    reye_pose(1), left_hand_pose(15), right_hand_pose(15), i.e. 55 joints
    with 3 values each.
    """
    # Number of joints in each fullpose segment (3 values per joint).
    BODY_POSE_LEN = 21
    HAND_POSE_LEN = 15
    JAW_POSE_LEN = 1
    EYE_POSE_LEN = 1
    # Value assigned to self.body_joints_num on construction.
    DEFAULT_BODY_JOINTS_NUM = 144
    BODY_POSE_KEYS = {
        'global_orient',
        'body_pose',
    }
    FULL_POSE_KEYS = {
        'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose',
        'jaw_pose', 'leye_pose', 'reye_pose'
    }

    def __init__(self,
                 gender: Union[Literal['female', 'male', 'neutral'],
                               None] = None,
                 fullpose: Union[np.ndarray, torch.Tensor, None] = None,
                 transl: Union[np.ndarray, torch.Tensor, None] = None,
                 betas: Union[np.ndarray, torch.Tensor, None] = None,
                 expression: Union[np.ndarray, torch.Tensor, None] = None,
                 mask: Union[np.ndarray, torch.Tensor, None] = None,
                 logger: Union[None, str, logging.Logger] = None) -> None:
        """Construct a SMPLXData instance with pre-set values.

        Args:
            gender (Union[
                    Literal['female', 'male', 'neutral'], None], optional):
                Gender of the body model.
                Should be one among ["female", "male", "neutral"].
                Defaults to None.
            fullpose (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for fullpose,
                in shape [frame_num, 55, 3]
                (55 is the value of get_fullpose_dim()).
                Defaults to None, zero-tensor will be created.
            transl (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for translation,
                in shape [frame_num, 3].
                Defaults to None, zero-tensor will be created.
            betas (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for betas,
                in shape [frame_num, betas_dim].
                Defaults to None,
                zero-tensor in shape [frame_num, 10] will be created.
            expression (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for expression,
                in shape [frame_num, expression_dim].
                Defaults to None,
                zero-tensor in shape [frame_num, 10] will be created.
            mask (Union[np.ndarray, torch.Tensor, None], optional):
                A tensor or ndarray for framewise visibility mask,
                in shape [n_frame, ].
                Defaults to None,
                one-tensor in shape [n_frame, ] will be created.
            logger (Union[None, str, logging.Logger], optional):
                Logger for logging. If None, root logger will be selected.
                Defaults to None.
        """
        SMPLData.__init__(
            self,
            gender=gender,
            transl=transl,
            fullpose=fullpose,
            betas=betas,
            mask=mask,
            logger=logger)
        # Only create a zero expression when none was given and the
        # parent constructor has not already stored one.
        if expression is None and 'expression' not in self:
            expression = np.zeros(shape=(self.get_batch_size(), 10))
        if expression is not None:
            self.set_expression(expression)
        self.body_joints_num = self.__class__.DEFAULT_BODY_JOINTS_NUM

    @classmethod
    def from_dict(cls, smpl_data_dict: Union['SMPLXData',
                                             dict]) -> 'SMPLXData':
        """Construct a body model data structure from a SMPLXData, or a
        degraded smplx_data in dict type.

        Args:
            smpl_data_dict (Union[SMPLXData, dict]):
                A SMPLXData, or a degraded smplx_data in dict type.
                It must contain the keys gender, fullpose, transl,
                betas and expression.

        Returns:
            SMPLXData:
                A SMPLXData instance load from dict.
        """
        min_keys = {'gender', 'fullpose', 'transl', 'betas', 'expression'}
        assert min_keys <= smpl_data_dict.keys(), \
            f'Missing keys: {min_keys - smpl_data_dict.keys()}'
        # NOTE(review): a 'mask' entry in smpl_data_dict is not forwarded,
        # so the new instance gets the constructor's default mask.
        # Confirm whether that is intended.
        return cls(
            gender=smpl_data_dict['gender'],
            fullpose=smpl_data_dict['fullpose'],
            transl=smpl_data_dict['transl'],
            betas=smpl_data_dict['betas'],
            expression=smpl_data_dict['expression'])

    @classmethod
    def get_fullpose_dim(cls) -> int:
        """Get dimension of full pose.

        Returns:
            int:
                Dim value. Full pose shall be
                in shape (frame_n, dim, 3)
        """
        global_orient_dim = 1
        return (global_orient_dim + cls.BODY_POSE_LEN + cls.JAW_POSE_LEN +
                2 * cls.EYE_POSE_LEN + 2 * cls.HAND_POSE_LEN)

    def set_expression(self, expression: Union[np.ndarray,
                                               torch.Tensor]) -> None:
        """Set expression data.

        The value is stored as a 2-D ndarray under the key 'expression'.

        Args:
            expression (Union[np.ndarray, torch.Tensor]):
                Expression parameters in ndarray or tensor,
                in shape [batch_size, n].
                n stands for any positive int, typically it's 10.
                A 1-D input of shape [n] is stored as [1, n].

        Raises:
            TypeError: Type of expression is not correct.
        """
        if isinstance(expression, torch.Tensor):
            expression = expression.detach().cpu().numpy()
        elif not isinstance(expression, np.ndarray):
            msg = ('Type of expression is not correct.\n'
                   f'Type: {type(expression)}.')
            self.logger.error(msg)
            raise TypeError(msg)
        if len(expression.shape) == 1:
            expression = expression[np.newaxis, ...]
        expression_dim = expression.shape[-1]
        expression_np = expression.reshape(-1, expression_dim)
        # Write through dict directly: self.__setitem__ routes the
        # 'expression' key back to this method.
        dict.__setitem__(self, 'expression', expression_np)

    def get_expression(self, repeat_expression: bool = True) -> np.ndarray:
        """Get expression.

        Args:
            repeat_expression (bool, optional):
                Whether to repeat expression when its first dim doesn't
                match batch_size. Defaults to True.

        Returns:
            ndarray:
                expression in shape [batch_size, expression_dims] or
                [1, expression_dims].
        """
        batch_size = self.get_global_orient().shape[0]
        expression = self.__getitem__('expression')
        # Only a single stored row is broadcast to the batch.
        if repeat_expression and \
                expression.shape[0] == 1 and \
                expression.shape[0] != batch_size:
            expression = expression.repeat(repeats=batch_size, axis=0)
        return expression

    def __setitem__(self, __k: Any, __v: Any) -> None:
        """Set item according to its key.

        The key 'expression' is handled by set_expression(), every other
        key is delegated to SMPLData.__setitem__().

        Args:
            __k (Any): Key in dict.
            __v (Any): Value in dict.
        """
        if __k == 'expression':
            self.set_expression(__v)
        else:
            SMPLData.__setitem__(self, __k, __v)

    @staticmethod
    def _zeros_like_backend(shape, ref_data):
        """Create zeros of ``shape`` in the same array library, dtype
        (and device, for tensors) as ``ref_data``."""
        if isinstance(ref_data, torch.Tensor):
            return torch.zeros(
                size=shape, dtype=ref_data.dtype, device=ref_data.device)
        return np.zeros(shape=shape, dtype=ref_data.dtype)

    @staticmethod
    def _concat_like_backend(data_list, dim, ref_data):
        """Concatenate ``data_list`` along ``dim`` with the array library
        that ``ref_data`` belongs to."""
        if isinstance(ref_data, torch.Tensor):
            return torch.cat(data_list, dim=dim)
        return np.concatenate(data_list, axis=dim)

    def _optional_pose(self, smplx_dict, key, joint_num, ref_data):
        """Return smplx_dict[key] flattened to [batch_size, -1], or zeros
        of shape [batch_size, joint_num * 3] when the key is absent."""
        batch_size = ref_data.shape[0]
        if key in smplx_dict:
            return smplx_dict[key].reshape(batch_size, -1)
        return self._zeros_like_backend(
            shape=[batch_size, joint_num * 3], ref_data=ref_data)

    def _hand_pose(self, smplx_dict, key, ref_data):
        """Return a hand pose of shape [batch_size, HAND_POSE_LEN * 3];
        zeros when the key is absent or has a different width."""
        batch_size = ref_data.shape[0]
        expected_dim = self.__class__.HAND_POSE_LEN * 3
        if key in smplx_dict:
            # Flatten to 2-D before checking the width, so that a vector
            # whose length is not a multiple of 3 is reported by the
            # warning below instead of failing inside reshape.
            hand_pose = smplx_dict[key].reshape(batch_size, -1)
            if hand_pose.shape[1] == expected_dim:
                return hand_pose
            self.logger.warning('SMPLX is using pca for hands,' +
                                f' {key} in SMPLXData will be set to zeros.')
        return self._zeros_like_backend(
            shape=[batch_size, expected_dim], ref_data=ref_data)

    def from_param_dict(self, smplx_dict: dict) -> None:
        """Load SMPLX parameters from smplx_dict, which is the output of a
        body model in most cases.

        Args:
            smplx_dict (dict):
                A dict of ndarray|Tensor parameters.
                global_orient and body_pose are necessary,
                expression, jaw_pose, leye_pose, reye_pose,
                left_hand_pose, right_hand_pose, transl and
                betas are optional.
                Missing pose segments are filled with zeros. A hand pose
                whose flattened width is not HAND_POSE_LEN * 3 is
                replaced by zeros with a warning.
                Other keys are ignored.

        Raises:
            KeyError: missing necessary keys.
            TypeError: global_orient is neither ndarray nor Tensor.
        """
        necessary_keys = {'global_orient', 'body_pose'}
        if not necessary_keys.issubset(smplx_dict):
            msg = ('Keys are not enough.\n'
                   f'smplx_dict\'s keys: {smplx_dict.keys()}')
            self.logger.error(msg)
            raise KeyError(msg)
        global_orient = smplx_dict['global_orient']
        # global_orient decides which array library is used for the zero
        # filling and the concatenation below, so reject other types here.
        if not isinstance(global_orient, (torch.Tensor, np.ndarray)):
            msg = ('Type of global_orient is not correct.\n'
                   f'Type: {type(global_orient)}.')
            self.logger.error(msg)
            raise TypeError(msg)
        # A 1-D global_orient is treated as a single frame.
        batch_size = global_orient.shape[0] \
            if len(global_orient.shape) > 1 else 1
        global_orient = global_orient.reshape(batch_size, 3)
        body_pose = smplx_dict['body_pose'].reshape(batch_size, -1)
        cls = self.__class__
        jaw_pose = self._optional_pose(
            smplx_dict, 'jaw_pose', cls.JAW_POSE_LEN, global_orient)
        leye_pose = self._optional_pose(
            smplx_dict, 'leye_pose', cls.EYE_POSE_LEN, global_orient)
        reye_pose = self._optional_pose(
            smplx_dict, 'reye_pose', cls.EYE_POSE_LEN, global_orient)
        left_hand_pose = self._hand_pose(
            smplx_dict, 'left_hand_pose', global_orient)
        right_hand_pose = self._hand_pose(
            smplx_dict, 'right_hand_pose', global_orient)
        fullpose = self._concat_like_backend(
            [
                global_orient, body_pose, jaw_pose, leye_pose, reye_pose,
                left_hand_pose, right_hand_pose
            ],
            dim=1,
            ref_data=global_orient)
        self.set_fullpose(fullpose)
        if 'transl' in smplx_dict:
            self.set_transl(smplx_dict['transl'])
        if 'betas' in smplx_dict:
            self.set_betas(smplx_dict['betas'])
        if 'expression' in smplx_dict:
            self.set_expression(smplx_dict['expression'])

    def to_param_dict(self,
                      repeat_betas: bool = True,
                      repeat_expression: bool = True) -> dict:
        """Split fullpose into global_orient, body_pose, jaw_pose, leye_pose,
        reye_pose, left_hand_pose, right_hand_pose, return all the necessary
        parameters in one dict.

        Args:
            repeat_betas (bool, optional):
                Whether to repeat betas when its first dim doesn't match
                batch_size. Defaults to True.
            repeat_expression (bool, optional):
                Whether to repeat expression when its first dim doesn't
                match batch_size. Defaults to True.

        Returns:
            dict:
                A dict of SMPLX data. It holds whatever
                SMPLData.to_param_dict() returns, with body_pose replaced,
                plus jaw_pose, leye_pose, reye_pose, left_hand_pose,
                right_hand_pose and expression. Every pose segment is in
                shape [batch_size, joint_num * 3].
        """
        dict_to_return = SMPLData.to_param_dict(
            self, repeat_betas=repeat_betas)
        # Drop the parent's body_pose; it is re-sliced below with the
        # SMPL-X body length.
        dict_to_return.pop('body_pose', None)
        fullpose = self.get_fullpose()
        batch_size = self.get_batch_size()
        cls = self.__class__
        # Segments in the order they are stored in fullpose, after the
        # single global_orient joint at index 0.
        segments = (
            ('body_pose', cls.BODY_POSE_LEN),
            ('jaw_pose', cls.JAW_POSE_LEN),
            ('leye_pose', cls.EYE_POSE_LEN),
            ('reye_pose', cls.EYE_POSE_LEN),
            ('left_hand_pose', cls.HAND_POSE_LEN),
            ('right_hand_pose', cls.HAND_POSE_LEN),
        )
        start_idx = 1
        for key, joint_num in segments:
            end_idx = start_idx + joint_num
            dict_to_return[key] = fullpose[:, start_idx:end_idx].reshape(
                batch_size, -1)
            start_idx = end_idx
        dict_to_return['expression'] = self.get_expression(
            repeat_expression=repeat_expression)
        return dict_to_return

    def to_tensor_dict(self,
                       repeat_betas: bool = True,
                       repeat_expression: bool = True,
                       device: Union[torch.device, str] = 'cpu') -> dict:
        """It is almost same as self.to_param_dict, but all the values are
        float32 tensors in a specified device.

        Args:
            repeat_betas (bool, optional):
                Whether to repeat betas when its first dim doesn't match
                batch_size. Defaults to True.
            repeat_expression (bool, optional):
                Whether to repeat expression when its first dim doesn't
                match batch_size. Defaults to True.
            device (Union[torch.device, str], optional):
                A specified device. Defaults to 'cpu'.

        Returns:
            dict: A dict of SMPLX data, with the same keys as
                self.to_param_dict() returns.
        """
        np_dict = self.to_param_dict(
            repeat_betas=repeat_betas, repeat_expression=repeat_expression)
        return {
            key: torch.tensor(value, device=device, dtype=torch.float32)
            for key, value in np_dict.items()
        }