Skip to content

Commit

Permalink
add flag to save ply and move save_ply to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
maturk committed Jan 10, 2025
1 parent 1809a6d commit b6c21fd
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 93 deletions.
101 changes: 8 additions & 93 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import math
import os
import time
import struct
from dataclasses import dataclass, field
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -42,6 +41,7 @@
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from gsplat.optimizers import SelectiveAdam
from gsplat.utils import save_ply


@dataclass
Expand Down Expand Up @@ -86,6 +86,8 @@ class Config:
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Whether to save ply file (storage size can be large)
save_ply: bool = False
# Steps to save the model as ply
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

Expand Down Expand Up @@ -188,97 +190,6 @@ def adjust_steps(self, factor: float):
assert_never(strategy)


def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None):
    """Write the splats to ``dir`` as a binary little-endian PLY file.

    Per-vertex properties, in order: ``x y z``, ``nx ny nz`` (always zero),
    the colour block, ``opacity``, ``scale_*`` and ``rot_*``. The colour block
    is ``f_dc_*`` derived from ``colors`` when it is given, otherwise
    ``f_dc_*`` from ``splats["sh0"]`` followed by ``f_rest_*`` from
    ``splats["shN"]``.

    Any splat with a NaN or Inf in means, scales, quats, opacities, sh0, shN
    (or ``colors`` when given) is dropped; the filter is applied per splat and
    to every array, including ``colors``, so rows stay aligned.

    Args:
        splats: Mapping with the keys "means", "scales", "quats",
            "opacities", "sh0" and "shN"; each value must support
            ``.detach().cpu().numpy()``.
        dir: Path of the output file (despite the name, not a directory).
        colors: Optional 2-D tensor with one row per splat. Each value is
            written as ``(c - 0.5) / 0.2820947917738781`` and replaces both
            SH blocks.

    Raises:
        ValueError: If ``means`` is not of shape (N, 3) or ``colors`` does not
            have exactly one row per splat.
    """
    print(f"Saving ply to {dir}")
    numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()}

    means = numpy_data["means"]
    if means.ndim != 2 or means.shape[1] != 3:
        # The header hard-codes x/y/z, so anything else would corrupt the file.
        raise ValueError(f"means must have shape (N, 3), got {means.shape}")
    num_total = means.shape[0]

    scales = numpy_data["scales"]
    quats = numpy_data["quats"]
    # Force a single column so opacity is checked per splat like the other
    # arrays (reducing a 1-D vector along axis 0 would give one global flag).
    opacities = numpy_data["opacities"].reshape(num_total, 1)

    def _flatten_sh(coeffs):
        # Swap the last two axes, then flatten them. The explicit column count
        # (instead of -1) keeps the reshape valid when there are zero splats.
        return coeffs.transpose(0, 2, 1).reshape(
            num_total, coeffs.shape[1] * coeffs.shape[2]
        )

    sh0 = _flatten_sh(numpy_data["sh0"])
    shN = _flatten_sh(numpy_data["shN"])

    color_dc = None
    if colors is not None:
        # Convert once, outside of any per-vertex work.
        rgb = colors.detach().cpu().numpy()
        if rgb.ndim != 2 or rgb.shape[0] != num_total:
            raise ValueError(
                f"colors must have shape ({num_total}, C), got {rgb.shape}"
            )
        # 0.2820947917738781 ~= 1 / (2 * sqrt(pi)). Computed in float64 and
        # rounded to float32 only when the table is assembled.
        color_dc = (rgb.astype(np.float64) - 0.5) / 0.2820947917738781

    # Keep only the splats whose every checked value is finite.
    checked = [means, scales, quats, opacities, sh0, shN]
    if color_dc is not None:
        checked.append(color_dc)
    valid = np.ones(num_total, dtype=bool)
    for arr in checked:
        valid &= np.isfinite(arr).all(axis=1)

    color_columns = [color_dc] if color_dc is not None else [sh0, shN]
    columns = [means, np.zeros_like(means), *color_columns, opacities, scales, quats]
    # One float32 row per surviving splat, laid out exactly as the header says.
    table = np.concatenate(
        [col[valid].astype(np.float32) for col in columns], axis=1
    ).astype("<f4", copy=False)
    num_points = table.shape[0]

    properties = ["x", "y", "z", "nx", "ny", "nz"]
    if color_dc is not None:
        properties += [f"f_dc_{j}" for j in range(color_dc.shape[1])]
    else:
        properties += [f"f_dc_{j}" for j in range(sh0.shape[1])]
        properties += [f"f_rest_{j}" for j in range(shN.shape[1])]
    properties.append("opacity")
    properties += [f"scale_{i}" for i in range(scales.shape[1])]
    properties += [f"rot_{i}" for i in range(quats.shape[1])]

    header = "".join(
        [
            "ply\n",
            "format binary_little_endian 1.0\n",
            f"element vertex {num_points}\n",
            *(f"property float {name}\n" for name in properties),
            "end_header\n",
        ]
    )

    with open(dir, "wb") as f:
        f.write(header.encode())
        # Single bulk write instead of one struct.pack call per float.
        f.write(table.tobytes())

def create_splats_with_optimizers(
parser: Parser,
init_type: str = "sfm",
Expand Down Expand Up @@ -832,7 +743,11 @@ def train(self):
torch.save(
data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
)
if step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1:
if (
step in [i - 1 for i in cfg.ply_steps]
or step == max_steps - 1
and cfg.save_ply
):
rgb = None
if self.cfg.app_opt:
# eval at origin to bake the appearance into the colors
Expand Down
93 changes: 93 additions & 0 deletions gsplat/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,101 @@
import math
import struct

import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np


def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None):
    """Write the splats to ``dir`` as a binary little-endian PLY file.

    Per-vertex properties, in order: ``x y z``, ``nx ny nz`` (always zero),
    the colour block, ``opacity``, ``scale_*`` and ``rot_*``. The colour block
    is ``f_dc_*`` derived from ``colors`` when it is given, otherwise
    ``f_dc_*`` from ``splats["sh0"]`` followed by ``f_rest_*`` from
    ``splats["shN"]``.

    Any splat with a NaN or Inf in means, scales, quats, opacities, sh0, shN
    (or ``colors`` when given) is dropped; the filter is applied per splat and
    to every array, including ``colors``, so rows stay aligned.

    Args:
        splats: Mapping with the keys "means", "scales", "quats",
            "opacities", "sh0" and "shN"; each value must support
            ``.detach().cpu().numpy()``.
        dir: Path of the output file (despite the name, not a directory).
        colors: Optional 2-D tensor with one row per splat. Each value is
            written as ``(c - 0.5) / 0.2820947917738781`` and replaces both
            SH blocks.

    Raises:
        ValueError: If ``means`` is not of shape (N, 3) or ``colors`` does not
            have exactly one row per splat.
    """
    print(f"Saving ply to {dir}")
    numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()}

    means = numpy_data["means"]
    if means.ndim != 2 or means.shape[1] != 3:
        # The header hard-codes x/y/z, so anything else would corrupt the file.
        raise ValueError(f"means must have shape (N, 3), got {means.shape}")
    num_total = means.shape[0]

    scales = numpy_data["scales"]
    quats = numpy_data["quats"]
    # Force a single column so opacity is checked per splat like the other
    # arrays (reducing a 1-D vector along axis 0 would give one global flag).
    opacities = numpy_data["opacities"].reshape(num_total, 1)

    def _flatten_sh(coeffs):
        # Swap the last two axes, then flatten them. The explicit column count
        # (instead of -1) keeps the reshape valid when there are zero splats.
        return coeffs.transpose(0, 2, 1).reshape(
            num_total, coeffs.shape[1] * coeffs.shape[2]
        )

    sh0 = _flatten_sh(numpy_data["sh0"])
    shN = _flatten_sh(numpy_data["shN"])

    color_dc = None
    if colors is not None:
        # Convert once, outside of any per-vertex work.
        rgb = colors.detach().cpu().numpy()
        if rgb.ndim != 2 or rgb.shape[0] != num_total:
            raise ValueError(
                f"colors must have shape ({num_total}, C), got {rgb.shape}"
            )
        # 0.2820947917738781 ~= 1 / (2 * sqrt(pi)). Computed in float64 and
        # rounded to float32 only when the table is assembled.
        color_dc = (rgb.astype(np.float64) - 0.5) / 0.2820947917738781

    # Keep only the splats whose every checked value is finite.
    checked = [means, scales, quats, opacities, sh0, shN]
    if color_dc is not None:
        checked.append(color_dc)
    valid = np.ones(num_total, dtype=bool)
    for arr in checked:
        valid &= np.isfinite(arr).all(axis=1)

    color_columns = [color_dc] if color_dc is not None else [sh0, shN]
    columns = [means, np.zeros_like(means), *color_columns, opacities, scales, quats]
    # One float32 row per surviving splat, laid out exactly as the header says.
    table = np.concatenate(
        [col[valid].astype(np.float32) for col in columns], axis=1
    ).astype("<f4", copy=False)
    num_points = table.shape[0]

    properties = ["x", "y", "z", "nx", "ny", "nz"]
    if color_dc is not None:
        properties += [f"f_dc_{j}" for j in range(color_dc.shape[1])]
    else:
        properties += [f"f_dc_{j}" for j in range(sh0.shape[1])]
        properties += [f"f_rest_{j}" for j in range(shN.shape[1])]
    properties.append("opacity")
    properties += [f"scale_{i}" for i in range(scales.shape[1])]
    properties += [f"rot_{i}" for i in range(quats.shape[1])]

    header = "".join(
        [
            "ply\n",
            "format binary_little_endian 1.0\n",
            f"element vertex {num_points}\n",
            *(f"property float {name}\n" for name in properties),
            "end_header\n",
        ]
    )

    with open(dir, "wb") as f:
        f.write(header.encode())
        # Single bulk write instead of one struct.pack call per float.
        f.write(table.tobytes())


def normalized_quat_to_rotmat(quat: Tensor) -> Tensor:
Expand Down

0 comments on commit b6c21fd

Please sign in to comment.