Source code for xrmocap.data_structure.limbs

# yapf: disable
import logging
import numpy as np
import torch
from typing import List, Union
from xrprimer.utils.log_utils import get_logger

# yapf: enable


[docs]class Limbs(): """A class for person limbs data, recording connection vectors between keypoints. Connections are the only necessary data, while human parts, points are optional. """ def __init__(self, connections: Union[np.ndarray, torch.Tensor], connection_names: Union[List[str], None] = None, parts: Union[List[List[int]], None] = None, part_names: Union[List[str], None] = None, points: Union[np.ndarray, torch.Tensor, None] = None, logger: Union[None, str, logging.Logger] = None) -> None: """A class for person limbs data, recording connection vectors between keypoints. Connections are the only necessary data. Connections record point indice, while parts record connection indice. Args: connections (Union[np.ndarray, torch.Tensor]): A tensor or ndarray for connections, in shape [n_conn, 2], conn[:, 0] are start point indice and conn[:, 1] are end point indice. connection_names (Union[List[str], None], optional): A list of connections names. If given, len(connection_names)==len(conn), else default names will be returned when getting connections. Defaults to None. parts (Union[List[List[int]], None], optional): A nested list, len(parts) is part number, and len(parts[0]) is connection number of the first part. Each element in parts[i] is an index of one connection. part_names (Union[List[str], None], optional): A list of part names. If given, len(part_names)==len(parts), else default names will be returned when getting parts. Defaults to None. points (Union[np.ndarray, torch.Tensor, None], optional): A tensor or ndarray for points, in shape [n_point, point_dim]. Defaults to None. logger (Union[None, str, logging.Logger], optional): Logger for logging. If None, root logger will be selected. Defaults to None. """ self.logger = get_logger(logger) self.connections = None self.connection_names = None self.parts = None self.part_names = None self.points = None self.logger = get_logger(None) self.set_connections(connections, connection_names) if parts is not None: self.set_parts(parts, part_names) if points is not None: self.set_points(points)
[docs] def set_connections(self, conn: Union[np.ndarray, torch.Tensor], conn_names: List[str] = None) -> None: """Set connection relations of the limbs. Names are optional. Args: conn (Union[np.ndarray, torch.Tensor]): A tensor or ndarray for connections, in shape [n_conn, 2], conn[:, 0] are start point indice and conn[:, 1] are end point indice. conn_names (List[str], optional): A list of connections names. If given, len(conn_names)==len(conn), else default names will be returned when getting connections. Defaults to None. Raises: TypeError: Type of connections is not correct. ValueError: Shape of connections is not correct. """ if isinstance(conn, torch.Tensor): connections = conn.detach().cpu().numpy() elif isinstance(conn, np.ndarray): connections = conn else: self.logger.error('Type of connections is not correct.\n' + f'Type: {type(conn)}.') raise TypeError connections = connections.astype(dtype=np.int32) # shape: connection_number, 2 if connections.shape[-1] != 2 or len(connections.shape) != 2: self.logger.error('Shape of connections should be' + ' [n_conn, 2].\n' + f'connections.shape: {connections.shape}.') raise ValueError self.connections = connections if conn_names is not None: if len(conn_names) == len(connections): self.connection_names = conn_names else: self.logger.warning( 'Length of connection_names is wrong, reset to None.\n' + f'len(conn_names): {len(conn_names)}\n' + f'len(connections): {len(connections)}') self.connection_names = None else: self.connection_names = None
[docs] def set_parts(self, parts: List[List[int]], part_names: List[str] = None) -> None: """Set parts of the limbs. If parts has been set, connections can be arranged by part when getting. Names are optional. Args: parts (List[List[int]]): A nested list, len(parts) is part number, and len(parts[0]) is connection number of the first part. Each element in parts[i] is an index of one connection. part_names (List[str], optional): A list of part names. If given, len(part_names)==len(parts), else default names will be returned when getting parts. Defaults to None. Raises: TypeError: Type of parts is not correct. ValueError: Type of connection index is not correct. """ if not isinstance(parts, list): self.logger.error('Type of parts is not correct.\n' + f'Type: {type(parts)}.\n' + 'Expect: list') raise TypeError # Type of conn index: int for conn_list in parts: for conn_index in conn_list: if not isinstance(conn_index, int): self.logger.error( 'Type of connection index is not correct.\n' + f'Type: {type(conn_index)}.\n' + 'Expect: int') raise TypeError self.parts = parts if part_names is not None: if len(part_names) == len(parts): self.part_names = part_names else: self.logger.warning( 'Length of part_names is wrong, reset to None.\n' + f'len(part_names): {len(part_names)}\n' + f'len(parts): {len(part_names)}') self.part_names = None else: self.part_names = None
[docs] def set_points(self, points: Union[np.ndarray, torch.Tensor]) -> None: """Set points of the limbs. Args: points (Union[np.ndarray, torch.Tensor]): A tensor or ndarray for points, in shape [n_point, point_dim]. Raises: TypeError: Type of points is not correct. ValueError: Shape of points should be [n_point, point_dim]. """ if isinstance(points, torch.Tensor): points = points.detach().cpu().numpy() elif not isinstance(points, np.ndarray): self.logger.error('Type of points is not correct.\n' + f'Type: {type(points)}.') raise TypeError # shape: connection_number, 2 if len(points.shape) != 2: self.logger.error('Shape of points should be' + ' [n_point, point_dim].\n' + f'points.shape: {points.shape}.') raise ValueError self.points = points
def clone(self) -> 'Limbs': def copy_if_not_None(data): if data is None: return None else: return data.copy() ret_limbs = Limbs( connections=self.connections.copy(), connection_names=self.connection_names, parts=copy_if_not_None(self.parts), part_names=self.part_names, points=copy_if_not_None(self.points), logger=self.logger) return ret_limbs
[docs] def get_points(self) -> Union[np.ndarray, None]: """Get points array, which might be None. Returns: Union[np.ndarray, None]: keypoints """ return self.points
[docs] def get_connections(self) -> np.ndarray: """Get connections of limbs. Returns: np.ndarray: connections """ return self.connections
[docs] def get_connection_names(self) -> List[str]: """Get names of connection. Returns: List[str]: A list of names """ if self.connection_names is None: ret_list = [ f'conn_{conn_index:03d}' for conn_index in range(len(self.connections)) ] else: ret_list = self.connection_names return ret_list
[docs] def get_parts(self) -> Union[List[List[int]], None]: """Get parts of limbs, which might be None. Returns: Union[List[List[int]], None]: parts """ return self.parts
[docs] def get_part_names(self) -> List[str]: """Get names of part. If self.part is None, an empty list will be returned. Returns: List[str]: A list of names """ if self.parts is None: ret_list = [] else: if self.part_names is None: ret_list = [ f'part_{part_index:03d}' for part_index in range(len(self.parts)) ] else: ret_list = self.part_names return ret_list
def __len__(self) -> int: """Get number of connections. Returns: int """ return len(self.connections)
[docs] def get_connections_in_parts(self) -> dict: """Get connection in parts. In each part, there's a list of. [start_pt_index, end_pt_index]. Returns: dict: keys are part names. """ ret_dict = {} part_names = self.get_part_names() for part_index, part_name in enumerate(part_names): connection_list = self.parts[part_index] ret_dict[part_name] = connection_list return ret_dict
[docs] def get_connections_by_names(self) -> dict: """Get connection by names. Returns: dict: keys are connection names and values are [start_pt_index, end_pt_index], in type ndarray. """ ret_dict = {} conn_names = self.get_connection_names() for conn_index, conn_name in enumerate(conn_names): connection = self.connections[conn_index] ret_dict[conn_name] = connection return ret_dict