Skip to content

Commit 6ce7843

Browse files
authored
add camera converter for gsplats (#813)
* add camera converter Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com> * fix few tests tolerance / flakiness Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com> --------- Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>
1 parent 741b8c0 commit 6ce7843

9 files changed

Lines changed: 184 additions & 4 deletions

File tree

docs/kaolin_ext.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def run_apidoc(_):
9393
"kaolin/render/camera/intrinsics.py",
9494
"kaolin/render/camera/legacy.py",
9595
"kaolin/render/camera/raygen.py",
96+
"kaolin/render/camera/gsplats.py",
9697
"kaolin/non_commercial/flexicubes/",
9798
"kaolin/non_commercial/flexicubes/flexicubes.py",
9899
"kaolin/non_commercial/flexicubes/tables.py"

kaolin/render/camera/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@
2121
from .coordinates import *
2222
from .legacy import *
2323
from .raygen import *
24+
from .gsplats import *
2425

2526
__all__ = [k for k in locals().keys() if not k.startswith('__')]

kaolin/render/camera/gsplats.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from __future__ import annotations
17+
18+
import torch
19+
from .camera import Camera
20+
from .intrinsics import CameraFOV
21+
22+
__all__ = [
23+
'kaolin_camera_to_gsplats',
24+
'gsplats_camera_to_kaolin'
25+
]
26+
27+
28+
29+
def kaolin_camera_to_gsplats(kal_camera, gs_cam_cls):
30+
"""Convert Kaolin Camera to `INRIA gaussian splats`_ camera.
31+
32+
.. note::
33+
This has been tested with the version commit `472689c`_
34+
35+
Args:
36+
kal_camera (Camera): camera to convert.
37+
gs_cam_cls (class):
38+
This is the gsplats ``Camera`` class,
39+
usually located in gsplats/scene/cameras.py.
40+
41+
Returns:
42+
(gsplats.scene.cameras.Camera): converted INRIA gsplats camera.
43+
44+
.. _INRIA gaussian splats:
45+
https://github.com/graphdeco-inria/gaussian-splatting
46+
.. _472689c:
47+
https://github.com/graphdeco-inria/gaussian-splatting/tree/472689c0dc70417448fb451bf529ae532d32c095
48+
"""
49+
R = kal_camera.extrinsics.R[0].clone()
50+
R[1:3] = -R[1:3]
51+
T = kal_camera.extrinsics.t.squeeze()
52+
T[1:3] = -T[1:3]
53+
return gs_cam_cls(colmap_id=0,
54+
R=R.transpose(1, 0).cpu().numpy(),
55+
T=T.cpu().numpy(),
56+
FoVx=kal_camera.fov(CameraFOV.HORIZONTAL, in_degrees=False),
57+
FoVy=kal_camera.fov(CameraFOV.VERTICAL, in_degrees=False),
58+
image=torch.zeros((3, kal_camera.height, kal_camera.width)), # fake
59+
gt_alpha_mask=None,
60+
image_name='fake',
61+
uid=0)
62+
63+
def gsplats_camera_to_kaolin(gs_camera):
64+
"""Convert `INRIA gaussian splats`_ camera to Kaolin camera.
65+
66+
.. note::
67+
This has been tested with the version commit `472689c`_
68+
69+
Args:
70+
gs_camera (gsplats.scene.cameras.Camera): camera to convert.
71+
72+
Returns:
73+
(Camera): converted Kaolin camera.
74+
75+
.. _INRIA gaussian splats:
76+
https://github.com/graphdeco-inria/gaussian-splatting
77+
.. _472689c:
78+
https://github.com/graphdeco-inria/gaussian-splatting/tree/472689c0dc70417448fb451bf529ae532d32c095
79+
"""
80+
view_mat = gs_camera.world_view_transform.transpose(1, 0).clone()
81+
view_mat[1:3] = -view_mat[1:3]
82+
res = Camera.from_args(
83+
view_matrix=view_mat,
84+
width=gs_camera.image_width, height=gs_camera.image_height,
85+
fov=gs_camera.FoVy, device='cpu')
86+
return res

kaolin/render/lighting/sg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def sg_from_sun(direction, strength, angle, color):
148148
149149
Args:
150150
strength (torch.Tensor):
151-
The strength of the suns, of shape :math:`(\text{num_suns},)`, [0..inf] expected,
151+
The strength of the suns, of shape :math:`(\text{num_suns},)`, [1..inf] expected,
152152
usually in low integer range.
153153
color (torch.Tensor):
154154
The color of the suns,

tests/python/kaolin/ops/mesh/test_mesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_average_face_vertex_features(self, device, dtype):
170170

171171
vertex_features = mesh.average_face_vertex_features(faces, face_features)
172172
if dtype == torch.half:
173-
assert torch.allclose(expected, vertex_features, rtol=1e-03, atol=1e-05)
173+
assert torch.allclose(expected, vertex_features, rtol=1e-3, atol=1e-3)
174174
else:
175175
assert torch.allclose(expected, vertex_features)
176176

tests/python/kaolin/ops/test_pointcloud.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def test_center_points(device, dtype):
3535
points[:, 0, :] = 1.0 # make sure 1 is included
3636
points[:, 1, :] = 0.0 # make sure 0 is included
3737
points = points - 0.5 # -0.5...0.5
38+
points = torch.clamp((torch.sign(points) * 1e-3) + points, -0.5, 0.5)
3839

3940
factors = 0.2 + 2 * torch.rand((B, 1, 1), device=device, dtype=dtype)
4041
translations = torch.rand((B, 1, 3), device=device, dtype=dtype) - 0.5
@@ -66,4 +67,4 @@ def test_center_points(device, dtype):
6667
points[0, :, :] = torch.tensor([0, 2., 4.], dtype=dtype, device=device).reshape((1, 3))
6768
points_centered = kaolin.ops.pointcloud.center_points(points * factors + translations, normalize=True)
6869
points[0, :, :] = 0
69-
assert torch.allclose(points, points_centered, atol=atol, rtol=rtol)
70+
assert torch.allclose(points, points_centered, atol=atol, rtol=rtol)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
2+
# All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import os
18+
import sys
19+
import subprocess
20+
import pytest
21+
import math
22+
import random
23+
import torch
24+
import shutil
25+
import numpy as np
26+
from git import Repo
27+
from kaolin.render.camera import kaolin_camera_to_gsplats, gsplats_camera_to_kaolin, Camera
28+
29+
# dealing with nvcr
30+
if torch.version.cuda == '12.5':
31+
pytest.skip("gsplats is not installable with CUDA 12.5", allow_module_level=True)
32+
33+
ROOT_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'gsplats')
34+
@pytest.fixture(scope="module")
35+
def gs_cam_cls():
36+
repo = Repo.clone_from(
37+
url='https://github.com/graphdeco-inria/gaussian-splatting',
38+
multi_options=['--recursive'],
39+
to_path=ROOT_DIR
40+
)
41+
sys.path.append(ROOT_DIR)
42+
subprocess.check_call([
43+
sys.executable, "-m", "pip", "install",
44+
os.path.join(ROOT_DIR, "submodules", "diff-gaussian-rasterization")
45+
])
46+
subprocess.check_call([
47+
sys.executable, "-m", "pip", "install",
48+
os.path.join(ROOT_DIR, "submodules", "simple-knn")
49+
])
50+
from .gsplats.scene.cameras import Camera as GSCamera
51+
52+
yield GSCamera
53+
sys.path.remove(ROOT_DIR)
54+
shutil.rmtree(ROOT_DIR)
55+
56+
57+
58+
class TestGsplats:
59+
def test_cycle(self, gs_cam_cls):
60+
kal_cam = Camera.from_args(
61+
eye=torch.rand((3,)),
62+
at=torch.rand((3,)),
63+
up=torch.nn.functional.normalize(torch.rand((3,)), dim=0),
64+
fov=random.random() * math.pi,
65+
width=512, height=512,
66+
)
67+
gs_cam = kaolin_camera_to_gsplats(kal_cam, gs_cam_cls)
68+
out_cam = gsplats_camera_to_kaolin(gs_cam)
69+
assert torch.allclose(out_cam, kal_cam)
70+
71+
def test_kaolin_to_gsplats_regression(self, gs_cam_cls):
72+
kal_cam = Camera.from_args(
73+
eye=torch.tensor([1., 2., 3.]),
74+
at=torch.tensor([0.3, 0.1, 0.2]),
75+
up=torch.tensor([0., 1., 0.]),
76+
fov=math.pi / 4,
77+
width=512, height=512,
78+
)
79+
gs_cam = kaolin_camera_to_gsplats(kal_cam, gs_cam_cls)
80+
expected_R = np.array([[ 0.9701425, 0.13336042, -0.20257968],
81+
[ 0., -0.83525735, -0.5498591 ],
82+
[-0.24253562, 0.53344166, -0.8103187 ]])
83+
expected_T = np.array([-0.24253559, -0.06317067, 3.733254 ])
84+
expected_fovx = torch.tensor([0.7854])
85+
expected_fovy = torch.tensor([0.7854])
86+
assert np.allclose(expected_R, gs_cam.R)
87+
assert np.allclose(expected_T, gs_cam.T)
88+
assert torch.allclose(expected_fovx, gs_cam.FoVx)
89+
assert torch.allclose(expected_fovy, gs_cam.FoVy)

tests/python/kaolin/render/lighting/test_sg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def test_from_sun(self):
420420

421421
class TestUtilities:
422422
def test_sg_from_sun(self):
423-
strength = torch.rand((7,), dtype=torch.float) * 10
423+
strength = torch.rand((7,), dtype=torch.float) * 10 + 1.
424424
direction = torch.rand((7, 3), dtype=torch.float)
425425
angle = torch.rand((7,), dtype=torch.float) * math.pi
426426
color = torch.rand((7, 3), dtype=torch.float)

tools/ci_requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@ pytest==7.3.0
77
pytest-cov==3.0.0
88
nbmake==1.4.1
99
jupyter
10+
plyfile
11+
gitpython

0 commit comments

Comments
 (0)