From b6c21fd516a2a14eb6a7be0d1521b64f686872ce Mon Sep 17 00:00:00 2001 From: maturk Date: Fri, 10 Jan 2025 21:42:15 +0200 Subject: [PATCH] add flag to save ply and move save_ply to utils --- examples/simple_trainer.py | 101 +++---------------------------------- gsplat/utils.py | 93 ++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 93 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 81576ae04..ca9271e81 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -2,7 +2,6 @@ import math import os import time -import struct from dataclasses import dataclass, field from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -42,6 +41,7 @@ from gsplat.rendering import rasterization from gsplat.strategy import DefaultStrategy, MCMCStrategy from gsplat.optimizers import SelectiveAdam +from gsplat.utils import save_ply @dataclass @@ -86,6 +86,8 @@ class Config: eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Whether to save ply file (storage size can be large) + save_ply: bool = False # Steps to save the model as ply ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) @@ -188,97 +190,6 @@ def adjust_steps(self, factor: float): assert_never(strategy) -def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None): - # Convert all tensors to numpy arrays in one go - print(f"Saving ply to {dir}") - numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()} - - means = numpy_data["means"] - scales = numpy_data["scales"] - quats = numpy_data["quats"] - opacities = numpy_data["opacities"] - - sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1) - shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1) - - # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays - invalid_mask = ( - np.isnan(means).any(axis=1) - | np.isinf(means).any(axis=1) - | np.isnan(scales).any(axis=1) - | np.isinf(scales).any(axis=1) - | np.isnan(quats).any(axis=1) - | np.isinf(quats).any(axis=1) - | np.isnan(opacities).any(axis=0) - | np.isinf(opacities).any(axis=0) - | np.isnan(sh0).any(axis=1) - | np.isinf(sh0).any(axis=1) - | np.isnan(shN).any(axis=1) - | np.isinf(shN).any(axis=1) - ) - - # Filter out rows with NaNs or Infs from all data arrays - means = means[~invalid_mask] - scales = scales[~invalid_mask] - quats = quats[~invalid_mask] - opacities = opacities[~invalid_mask] - sh0 = sh0[~invalid_mask] - shN = shN[~invalid_mask] - - num_points = means.shape[0] - - with open(dir, "wb") as f: - # Write PLY header - f.write(b"ply\n") - f.write(b"format binary_little_endian 1.0\n") - f.write(f"element vertex {num_points}\n".encode()) - f.write(b"property float x\n") - f.write(b"property float y\n") - f.write(b"property float z\n") - f.write(b"property float nx\n") - f.write(b"property float ny\n") - f.write(b"property float nz\n") - - if colors is not None: - for j in range(colors.shape[1]): - f.write(f"property float f_dc_{j}\n".encode()) - else: - for i, data in enumerate([sh0, shN]): - prefix = "f_dc" if i == 0 else "f_rest" - for j in range(data.shape[1]): - f.write(f"property float {prefix}_{j}\n".encode()) - - f.write(b"property float opacity\n") - - for i in range(scales.shape[1]): - f.write(f"property float scale_{i}\n".encode()) - for i in range(quats.shape[1]): - f.write(f"property float rot_{i}\n".encode()) - - f.write(b"end_header\n") - - # Write vertex data - for i in range(num_points): - f.write(struct.pack(" Tensor: