Skip to content

Commit

Permalink
More optimizations for Gaussian rendering (#257)
Browse files Browse the repository at this point in the history
* More optimizations for Gaussian rendering

* Formatting nits
  • Loading branch information
brentyi authored Jul 29, 2024
1 parent 86796a1 commit a52d9a6
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 73 deletions.
4 changes: 4 additions & 0 deletions src/viser/_scene_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,10 +980,14 @@ def _add_gaussian_splats(
buffer = onp.concatenate(
[
# First texelFetch.
# - xyz (96 bits): centers.
centers.astype(onp.float32).view(onp.uint8),
# - w (32 bits): this is reserved for use by the renderer.
onp.zeros((num_gaussians, 4), dtype=onp.uint8),
# Second texelFetch.
# - xyz (96 bits): upper-triangular Cholesky factor of covariance.
cov_cholesky_triu.astype(onp.float16).copy().view(onp.uint8),
# - w (32 bits): rgba.
_colors_to_uint8(rgbs),
_colors_to_uint8(opacities),
],
Expand Down
26 changes: 21 additions & 5 deletions src/viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ import "./App.css";
import { Notifications } from "@mantine/notifications";

import {
AdaptiveDpr,
AdaptiveEvents,
CameraControls,
Environment,
PerformanceMonitor,
} from "@react-three/drei";
import * as THREE from "three";
import { Canvas, useThree, useFrame } from "@react-three/fiber";
Expand Down Expand Up @@ -301,7 +300,6 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) {
width: "100%",
height: "100%",
}}
performance={{ min: 0.95 }}
ref={viewer.canvasRef}
// Handle scene click events (onPointerDown, onPointerMove, onPointerUp)
onPointerDown={(e) => {
Expand Down Expand Up @@ -449,8 +447,7 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) {
>
{children}
<BackgroundImage />
<AdaptiveDpr pixelated />
<AdaptiveEvents />
<AdaptiveDpr />
<SceneContextSetter />
<SynchronizedCameraControls />
<SceneNodeThreeObject name="" parent={null} />
Expand All @@ -465,6 +462,25 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) {
);
}

function AdaptiveDpr() {
const setDpr = useThree((state) => state.setDpr);
return (
<PerformanceMonitor
factor={0.5}
ms={100}
iterations={2}
step={0.1}
onChange={({ factor, fps, refreshrate }) => {
const dpr = window.devicePixelRatio * (0.2 + 0.8 * factor);
console.log(
`[Performance] Setting DPR to ${dpr}; FPS=${fps}/${refreshrate}`,
);
setDpr(dpr);
}}
/>
);
}

/* HTML Canvas, for drawing 2D. */
function Viewer2DCanvas() {
const viewer = React.useContext(ViewerContext)!;
Expand Down
57 changes: 21 additions & 36 deletions src/viser/client/src/Splatting/GaussianSplats.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
// Index from the splat sorter.
attribute uint sortedIndex;
// Which group transform should be applied to each Gaussian.
attribute uint sortedGroupIndex;
// Buffers for splat data; each Gaussian gets 4 floats and 4 int32s. We just
// copy quadjr for this.
uniform usampler2D textureBuffer;
Expand Down Expand Up @@ -87,7 +84,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
// Fetch from textures.
uvec4 floatBufferData = texelFetch(textureBuffer, texPos0, 0);
mat4 T_camera_group = getGroupTransform(sortedGroupIndex);
mat4 T_camera_group = getGroupTransform(floatBufferData.w);
// Any early return will discard the fragment.
gl_Position = vec4(0.0, 0.0, 2000.0, 1.0);
Expand Down Expand Up @@ -197,7 +194,6 @@ export default function GlobalGaussianSplats() {
const merged = mergeGaussianGroups(groupBufferFromName);
const meshProps = useGaussianMeshProps(
merged.gaussianBuffer,
merged.groupIndices,
merged.numGroups,
);

Expand All @@ -210,13 +206,6 @@ export default function GlobalGaussianSplats() {
meshProps.sortedIndexAttribute.set(sortedIndices);
meshProps.sortedIndexAttribute.needsUpdate = true;

// Update group indices if needed.
if (merged.numGroups >= 2) {
const sortedGroupIndices = e.data.sortedGroupIndices as Uint32Array;
meshProps.sortedGroupIndexAttribute.set(sortedGroupIndices);
meshProps.sortedGroupIndexAttribute.needsUpdate = true;
}

// Trigger initial render.
if (!initializedBufferTexture) {
meshProps.material.uniforms.numGaussians.value = merged.numGaussians;
Expand Down Expand Up @@ -278,7 +267,7 @@ export default function GlobalGaussianSplats() {
const T_camera_world = state.camera.matrixWorldInverse;
const groupVisibles: boolean[] = [];
let visibilitiesChanged = false;
for (const [sortedGroupIndex, name] of Object.keys(
for (const [groupIndex, name] of Object.keys(
groupBufferFromName,
).entries()) {
const node = viewer.nodeRefFromName.current[name];
Expand All @@ -292,12 +281,12 @@ export default function GlobalGaussianSplats() {
colMajorElements[10],
colMajorElements[14],
],
sortedGroupIndex * 4,
groupIndex * 4,
);
const rowMajorElements = tmpT_camera_group.transpose().elements;
meshProps.rowMajorT_camera_groups.set(
rowMajorElements.slice(0, 12),
sortedGroupIndex * 12,
groupIndex * 12,
);

// Determine visibility. If the parent has unmountWhenInvisible=true, the
Expand All @@ -310,9 +299,9 @@ export default function GlobalGaussianSplats() {
visibleNow = visibleNow && ancestor.visible;
});
}
groupVisibles.push(visibleNow && prevVisibles[sortedGroupIndex] === true);
if (prevVisibles[sortedGroupIndex] !== visibleNow) {
prevVisibles[sortedGroupIndex] = visibleNow;
groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true);
if (prevVisibles[groupIndex] !== visibleNow) {
prevVisibles[groupIndex] = visibleNow;
visibilitiesChanged = true;
}
}
Expand All @@ -333,7 +322,6 @@ export default function GlobalGaussianSplats() {
// for the shader and not for the sorter; that way when we "show" a group
// of Gaussians the correct rendering order is immediately available.
for (const [i, visible] of groupVisibles.entries()) {
console.log(i, visible);
if (!visible) {
meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10;
meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10;
Expand Down Expand Up @@ -370,15 +358,25 @@ function mergeGaussianGroups(groupBufferFromName: {
const groupIndices = new Uint32Array(numGaussians);

let offset = 0;
for (const [sortedGroupIndex, groupBuffer] of Object.values(
for (const [groupIndex, groupBuffer] of Object.values(
groupBufferFromName,
).entries()) {
groupIndices.fill(
sortedGroupIndex,
groupIndex,
offset / 8,
(offset + groupBuffer.length) / 8,
);
gaussianBuffer.set(groupBuffer, offset);

// Each Gaussian is allocated
// - 12 bytes for center x, y, z (float32)
// - 4 bytes for group index (uint32); we're filling this in now
//
// - 12 bytes for covariance (6 terms, float16)
// - 4 bytes for RGBA (uint8)
for (let i = 0; i < groupBuffer.length; i += 8) {
gaussianBuffer[offset + i + 3] = groupIndex;
}
offset += groupBuffer.length;
}

Expand All @@ -387,12 +385,8 @@ function mergeGaussianGroups(groupBufferFromName: {
}

/**Hook to generate properties for rendering Gaussians via a three.js mesh.*/
function useGaussianMeshProps(
gaussianBuffer: Uint32Array,
groupIndices: Uint32Array,
numGroups: number,
) {
const numGaussians = groupIndices.length;
function useGaussianMeshProps(gaussianBuffer: Uint32Array, numGroups: number) {
const numGaussians = gaussianBuffer.length / 8;
const maxTextureSize = useThree((state) => state.gl).capabilities
.maxTextureSize;

Expand All @@ -418,14 +412,6 @@ function useGaussianMeshProps(
sortedIndexAttribute.setUsage(THREE.DynamicDrawUsage);
geometry.setAttribute("sortedIndex", sortedIndexAttribute);

// Which group is each Gaussian in?
const sortedGroupIndexAttribute = new THREE.InstancedBufferAttribute(
groupIndices.slice(), // Copies the array.
1,
);
sortedGroupIndexAttribute.setUsage(THREE.DynamicDrawUsage);
geometry.setAttribute("sortedGroupIndex", sortedGroupIndexAttribute);

// Create texture buffers.
const textureWidth = Math.min(numGaussians * 2, maxTextureSize);
const textureHeight = Math.ceil((numGaussians * 2) / textureWidth);
Expand Down Expand Up @@ -465,7 +451,6 @@ function useGaussianMeshProps(
material,
textureBuffer,
sortedIndexAttribute,
sortedGroupIndexAttribute,
textureT_camera_groups,
rowMajorT_camera_groups,
};
Expand Down
36 changes: 10 additions & 26 deletions src/viser/client/src/Splatting/SplatSortWorker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,26 @@ export type SorterWorkerIncoming =
{
let sorter: any = null;
let Tz_camera_groups: Float32Array | null = null;
let groupIndices: Uint32Array | null = null;
let sortedGroupIndices: Uint32Array | null = null;
let sortRunning = false;
const throttledSort = () => {
if (
sorter === null ||
Tz_camera_groups === null ||
groupIndices === null ||
sortedGroupIndices === null
) {
if (sorter === null || Tz_camera_groups === null) {
setTimeout(throttledSort, 1);
return;
}
if (sortRunning) return;

sortRunning = true;
const lastView = Tz_camera_groups;
const sortedIndices = sorter.sort(Tz_camera_groups);

const numGroups = Tz_camera_groups.length / 4;
if (numGroups >= 2) {
// Multiple groups: we need to update the per-Gaussian group indices.
for (const [index, sortedIndex] of sortedIndices.entries()) {
sortedGroupIndices[index] = groupIndices[sortedIndex];
}
self.postMessage({
sortedIndices: sortedIndices,
sortedGroupIndices: sortedGroupIndices,
});
} else {
self.postMessage({
sortedIndices: sortedIndices,
});
}
// Important: we clone the output so we can transfer the buffer to the main
// thread. Compared to relying on postMessage for copying, this reduces
// backlog artifacts.
const sortedIndices = (
sorter.sort(Tz_camera_groups) as Uint32Array
).slice();

// @ts-ignore
self.postMessage({ sortedIndices: sortedIndices }, [sortedIndices.buffer]);

setTimeout(() => {
sortRunning = false;
Expand All @@ -75,8 +61,6 @@ export type SorterWorkerIncoming =
data.setBuffer,
data.setGroupIndices,
);
groupIndices = data.setGroupIndices;
sortedGroupIndices = groupIndices.slice();
} else if ("setTz_camera_groups" in data) {
// Update object transforms.
Tz_camera_groups = data.setTz_camera_groups;
Expand Down
Binary file modified src/viser/client/src/Splatting/WasmSorter/Sorter.wasm
Binary file not shown.
12 changes: 6 additions & 6 deletions src/viser/client/src/Splatting/WasmSorter/sorter.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include <emscripten/bind.h>
#include <emscripten/val.h>
#include <wasm_simd128.h>

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>
#include <wasm_simd128.h>

#include <emscripten/bind.h>
#include <emscripten/val.h>

/** SIMD dot product between two 4D vectors. */
__attribute__((always_inline)) inline float
Expand Down Expand Up @@ -41,6 +41,7 @@ __attribute__((always_inline)) inline int32_t max_i32x4(v128_t vector) {
class Sorter {
std::vector<v128_t> centers_homog; // Centers as homogeneous coordinates.
std::vector<uint32_t> group_indices;
std::vector<uint32_t> sorted_indices;

public:
Sorter(
Expand All @@ -51,7 +52,7 @@ class Sorter {
const float *floatBuffer =
reinterpret_cast<const float *>(bufferVec.data());
const int32_t num_gaussians = bufferVec.size() / 8;

sorted_indices.resize(num_gaussians);
centers_homog.resize(num_gaussians);
for (int32_t i = 0; i < num_gaussians; i++) {
centers_homog[i] = wasm_f32x4_make(
Expand Down Expand Up @@ -173,7 +174,6 @@ class Sorter {
}

// Update and return sorted indices.
std::vector<uint32_t> sorted_indices(num_gaussians);
for (int32_t i = 0; i < num_gaussians; i++)
sorted_indices[starts0[((int32_t *)&gaussian_zs[0])[i]]++] = i;
return emscripten::val(emscripten::typed_memory_view(
Expand Down

0 comments on commit a52d9a6

Please sign in to comment.