# Source code for xrmocap.ops.projection.pytorch_projector

# yapf: disable
import logging
import torch
from typing import List, Union
from xrprimer.data_structure.camera import FisheyeCameraParameter
from xrprimer.ops.projection.base_projector import BaseProjector
from xrprimer.utils.log_utils import get_logger

# yapf: enable


class PytorchProjector(BaseProjector):
    """Project 3D points onto every camera view with PyTorch.

    Each view's 3x4 projection matrix is built as ``K @ [R|t]`` from the
    camera's intrinsic and extrinsic parameters (see
    ``prepare_project_mat``).
    """
    # Camera convention the parameters are expected to follow.
    CAMERA_CONVENTION = 'opencv'
    # Flag declaring that the extrinsics are world-to-camera; the
    # projection matrix uses ``extrinsic_r``/``extrinsic_t`` directly as
    # ``[R|t]``.
    CAMERA_WORLD2CAM = True

    def __init__(self,
                 camera_parameters: List[FisheyeCameraParameter],
                 logger: Union[None, str, logging.Logger] = None) -> None:
        """PytorchProjector for points projection.

        Args:
            camera_parameters (List[FisheyeCameraParameter]):
                A list of FisheyeCameraParameter, one per view.
            logger (Union[None, str, logging.Logger], optional):
                Logger or logger name, resolved through ``get_logger``.
                Defaults to None.
        """
        BaseProjector.__init__(self, camera_parameters)
        self.logger = get_logger(logger)

    def project(self,
                points: torch.Tensor,
                points_mask: Union[torch.Tensor, None] = None) -> torch.Tensor:
        """Project points with self.camera_parameters.

        Args:
            points (torch.Tensor):
                points3d, in shape [n_point, 3]. Only the first three
                entries of the last dimension are used and any leading
                dimensions are flattened into n_point.
            points_mask (torch.Tensor, optional):
                mask, in shape [n_point, 1] (any shape with n_point
                elements is accepted). If points_mask[index] == 1,
                points[index] is valid for projection, else it is
                ignored and its output stays (0, 0).
                Defaults to None, meaning all points are valid.

        Returns:
            torch.Tensor:
                points2d, in shape [n_view, n_point, 2], on the same
                device as ``points`` and in float32.
        """
        points3d = points[..., :3].reshape(-1, 3).float()
        device = points3d.device
        dtype = points3d.dtype
        n_point = points3d.shape[0]
        n_view = len(self.camera_parameters)
        # Invalid points keep the zero initialisation.
        points2d = torch.zeros((n_view, n_point, 2),
                               dtype=dtype,
                               device=device)
        if points_mask is None:
            points_mask = torch.ones(
                n_point, dtype=torch.uint8, device=device)
        else:
            points_mask = points_mask.reshape(-1).to(device)
        valid_idxs = torch.where(points_mask == 1)[0]
        n_valid = valid_idxs.numel()

        mview_project_mat = self.prepare_project_mat(
            device=device, dtype=dtype)  # [n_view, 3, 4]
        valid_points3d = points3d[valid_idxs, :].T  # [3, n_valid]
        # Explicit (1, n_valid) shape: reshaping an empty tensor with -1
        # raises, which used to crash when no point was valid.
        ones_row = torch.ones((1, n_valid), dtype=dtype, device=device)
        points3d_homo = torch.cat((valid_points3d, ones_row),
                                  dim=0)  # [4, n_valid]
        points2d_homo = mview_project_mat @ points3d_homo
        points2d_homo = points2d_homo.transpose(2, 1)  # [n_view, n_valid, 3]
        # Small epsilon guards against division by exactly zero depth.
        proj_points2d = points2d_homo[..., :2] / (
            points2d_homo[..., 2:3] + 1e-5)
        points2d[:, valid_idxs, :] = proj_points2d
        return points2d

    def project_single_point(self, points: torch.Tensor) -> torch.Tensor:
        """Project a single point with self.camera_parameters.

        Args:
            points (torch.Tensor):
                points3d, in shape [3].

        Returns:
            torch.Tensor:
                points2d, in shape [n_view, 2].
        """
        points3d = points.reshape(1, 3)
        # project() returns [n_view, 1, 2]; drop the point axis.
        return torch.squeeze(self.project(points3d), dim=1)

    def prepare_project_mat(self,
                            device: Union[None, str, torch.device] = None,
                            dtype: Union[None, torch.dtype] = None
                            ) -> torch.Tensor:
        """Build the per-view projection matrices ``K @ [R|t]``.

        Args:
            device (Union[None, str, torch.device], optional):
                Device of the returned tensor. Defaults to None
                (torch's default device).
            dtype (Union[None, torch.dtype], optional):
                Dtype of the returned tensor. Defaults to None
                (torch's default dtype).

        Returns:
            torch.Tensor:
                Projection matrices, in shape [n_view, 3, 4].
        """
        n_view = len(self.camera_parameters)
        mview_project_mat = torch.zeros((n_view, 3, 4),
                                        dtype=dtype,
                                        device=device)
        mat_dtype = mview_project_mat.dtype
        for idx, cam_param in enumerate(self.camera_parameters):
            intrinsic = torch.tensor(
                cam_param.get_intrinsic(k_dim=3), dtype=mat_dtype)
            rotation = torch.tensor(cam_param.extrinsic_r, dtype=mat_dtype)
            # reshape accepts both a flat [3] and a column [3, 1] vector.
            translation = torch.tensor(
                cam_param.extrinsic_t, dtype=mat_dtype).reshape(3, 1)
            extrinsic = torch.cat((rotation, translation), dim=1)  # [3, 4]
            # Computed on CPU, copied into the target device buffer.
            mview_project_mat[idx] = intrinsic @ extrinsic
        return mview_project_mat