From a52d9a68657acce2662b3cc4c3e1b6f29a5fffbe Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Mon, 29 Jul 2024 16:47:29 -0700 Subject: [PATCH] More optimizations for Gaussian rendering (#257) * More optimizations for Gaussian rendering * Formatting nits --- src/viser/_scene_api.py | 4 ++ src/viser/client/src/App.tsx | 26 ++++++-- .../client/src/Splatting/GaussianSplats.tsx | 57 +++++++----------- .../client/src/Splatting/SplatSortWorker.ts | 36 +++-------- .../src/Splatting/WasmSorter/Sorter.wasm | Bin 21112 -> 21164 bytes .../src/Splatting/WasmSorter/sorter.cpp | 12 ++-- 6 files changed, 62 insertions(+), 73 deletions(-) diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 6e23a98e..56fce929 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -980,10 +980,14 @@ def _add_gaussian_splats( buffer = onp.concatenate( [ # First texelFetch. + # - xyz (96 bits): centers. centers.astype(onp.float32).view(onp.uint8), + # - w (32 bits): this is reserved for use by the renderer. onp.zeros((num_gaussians, 4), dtype=onp.uint8), # Second texelFetch. + # - xyz (96 bits): upper-triangular Cholesky factor of covariance. cov_cholesky_triu.astype(onp.float16).copy().view(onp.uint8), + # - w (32 bits): rgba. _colors_to_uint8(rgbs), _colors_to_uint8(opacities), ], diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx index 2e869357..8b436344 100644 --- a/src/viser/client/src/App.tsx +++ b/src/viser/client/src/App.tsx @@ -6,10 +6,9 @@ import "./App.css"; import { Notifications } from "@mantine/notifications"; import { - AdaptiveDpr, - AdaptiveEvents, CameraControls, Environment, + PerformanceMonitor, } from "@react-three/drei"; import * as THREE from "three"; import { Canvas, useThree, useFrame } from "@react-three/fiber"; @@ -301,7 +300,6 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { width: "100%", height: "100%", }} - performance={{ min: 0.95 }} ref={viewer.canvasRef} // Handle scene click events (onPointerDown, onPointerMove, onPointerUp) onPointerDown={(e) => { @@ -449,8 +447,7 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { > {children} - - + @@ -465,6 +462,25 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { ); } +function AdaptiveDpr() { + const setDpr = useThree((state) => state.setDpr); + return ( + { + const dpr = window.devicePixelRatio * (0.2 + 0.8 * factor); + console.log( + `[Performance] Setting DPR to ${dpr}; FPS=${fps}/${refreshrate}`, + ); + setDpr(dpr); + }} + /> + ); +} + /* HTML Canvas, for drawing 2D. */ function Viewer2DCanvas() { const viewer = React.useContext(ViewerContext)!; diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index 71d814f1..b0bfab44 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -31,9 +31,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // Index from the splat sorter. attribute uint sortedIndex; - // Which group transform should be applied to each Gaussian. - attribute uint sortedGroupIndex; - // Buffers for splat data; each Gaussian gets 4 floats and 4 int32s. We just // copy quadjr for this. uniform usampler2D textureBuffer; @@ -87,7 +84,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // Fetch from textures. uvec4 floatBufferData = texelFetch(textureBuffer, texPos0, 0); - mat4 T_camera_group = getGroupTransform(sortedGroupIndex); + mat4 T_camera_group = getGroupTransform(floatBufferData.w); // Any early return will discard the fragment. gl_Position = vec4(0.0, 0.0, 2000.0, 1.0); @@ -197,7 +194,6 @@ export default function GlobalGaussianSplats() { const merged = mergeGaussianGroups(groupBufferFromName); const meshProps = useGaussianMeshProps( merged.gaussianBuffer, - merged.groupIndices, merged.numGroups, ); @@ -210,13 +206,6 @@ export default function GlobalGaussianSplats() { meshProps.sortedIndexAttribute.set(sortedIndices); meshProps.sortedIndexAttribute.needsUpdate = true; - // Update group indices if needed. - if (merged.numGroups >= 2) { - const sortedGroupIndices = e.data.sortedGroupIndices as Uint32Array; - meshProps.sortedGroupIndexAttribute.set(sortedGroupIndices); - meshProps.sortedGroupIndexAttribute.needsUpdate = true; - } - // Trigger initial render. if (!initializedBufferTexture) { meshProps.material.uniforms.numGaussians.value = merged.numGaussians; @@ -278,7 +267,7 @@ export default function GlobalGaussianSplats() { const T_camera_world = state.camera.matrixWorldInverse; const groupVisibles: boolean[] = []; let visibilitiesChanged = false; - for (const [sortedGroupIndex, name] of Object.keys( + for (const [groupIndex, name] of Object.keys( groupBufferFromName, ).entries()) { const node = viewer.nodeRefFromName.current[name]; @@ -292,12 +281,12 @@ export default function GlobalGaussianSplats() { colMajorElements[10], colMajorElements[14], ], - sortedGroupIndex * 4, + groupIndex * 4, ); const rowMajorElements = tmpT_camera_group.transpose().elements; meshProps.rowMajorT_camera_groups.set( rowMajorElements.slice(0, 12), - sortedGroupIndex * 12, + groupIndex * 12, ); // Determine visibility. If the parent has unmountWhenInvisible=true, the @@ -310,9 +299,9 @@ export default function GlobalGaussianSplats() { visibleNow = visibleNow && ancestor.visible; }); } - groupVisibles.push(visibleNow && prevVisibles[sortedGroupIndex] === true); - if (prevVisibles[sortedGroupIndex] !== visibleNow) { - prevVisibles[sortedGroupIndex] = visibleNow; + groupVisibles.push(visibleNow && prevVisibles[groupIndex] === true); + if (prevVisibles[groupIndex] !== visibleNow) { + prevVisibles[groupIndex] = visibleNow; visibilitiesChanged = true; } } @@ -333,7 +322,6 @@ export default function GlobalGaussianSplats() { // for the shader and not for the sorter; that way when we "show" a group // of Gaussians the correct rendering order is immediately available. for (const [i, visible] of groupVisibles.entries()) { - console.log(i, visible); if (!visible) { meshProps.rowMajorT_camera_groups[i * 12 + 3] = 1e10; meshProps.rowMajorT_camera_groups[i * 12 + 7] = 1e10; @@ -370,15 +358,25 @@ function mergeGaussianGroups(groupBufferFromName: { const groupIndices = new Uint32Array(numGaussians); let offset = 0; - for (const [sortedGroupIndex, groupBuffer] of Object.values( + for (const [groupIndex, groupBuffer] of Object.values( groupBufferFromName, ).entries()) { groupIndices.fill( - sortedGroupIndex, + groupIndex, offset / 8, (offset + groupBuffer.length) / 8, ); gaussianBuffer.set(groupBuffer, offset); + + // Each Gaussian is allocated + // - 12 bytes for center x, y, z (float32) + // - 4 bytes for group index (uint32); we're filling this in now + // + // - 12 bytes for covariance (6 terms, float16) + // - 4 bytes for RGBA (uint8) + for (let i = 0; i < groupBuffer.length; i += 8) { + gaussianBuffer[offset + i + 3] = groupIndex; + } offset += groupBuffer.length; } @@ -387,12 +385,8 @@ function mergeGaussianGroups(groupBufferFromName: { } /**Hook to generate properties for rendering Gaussians via a three.js mesh.*/ -function useGaussianMeshProps( - gaussianBuffer: Uint32Array, - groupIndices: Uint32Array, - numGroups: number, -) { - const numGaussians = groupIndices.length; +function useGaussianMeshProps(gaussianBuffer: Uint32Array, numGroups: number) { + const numGaussians = gaussianBuffer.length / 8; const maxTextureSize = useThree((state) => state.gl).capabilities .maxTextureSize; @@ -418,14 +412,6 @@ function useGaussianMeshProps( sortedIndexAttribute.setUsage(THREE.DynamicDrawUsage); geometry.setAttribute("sortedIndex", sortedIndexAttribute); - // Which group is each Gaussian in? - const sortedGroupIndexAttribute = new THREE.InstancedBufferAttribute( - groupIndices.slice(), // Copies the array. - 1, - ); - sortedGroupIndexAttribute.setUsage(THREE.DynamicDrawUsage); - geometry.setAttribute("sortedGroupIndex", sortedGroupIndexAttribute); - // Create texture buffers. const textureWidth = Math.min(numGaussians * 2, maxTextureSize); const textureHeight = Math.ceil((numGaussians * 2) / textureWidth); @@ -465,7 +451,6 @@ function useGaussianMeshProps( material, textureBuffer, sortedIndexAttribute, - sortedGroupIndexAttribute, textureT_camera_groups, rowMajorT_camera_groups, }; diff --git a/src/viser/client/src/Splatting/SplatSortWorker.ts b/src/viser/client/src/Splatting/SplatSortWorker.ts index 30e6d54b..c535f29b 100644 --- a/src/viser/client/src/Splatting/SplatSortWorker.ts +++ b/src/viser/client/src/Splatting/SplatSortWorker.ts @@ -16,16 +16,9 @@ export type SorterWorkerIncoming = { let sorter: any = null; let Tz_camera_groups: Float32Array | null = null; - let groupIndices: Uint32Array | null = null; - let sortedGroupIndices: Uint32Array | null = null; let sortRunning = false; const throttledSort = () => { - if ( - sorter === null || - Tz_camera_groups === null || - groupIndices === null || - sortedGroupIndices === null - ) { + if (sorter === null || Tz_camera_groups === null) { setTimeout(throttledSort, 1); return; } @@ -33,23 +26,16 @@ export type SorterWorkerIncoming = sortRunning = true; const lastView = Tz_camera_groups; - const sortedIndices = sorter.sort(Tz_camera_groups); - const numGroups = Tz_camera_groups.length / 4; - if (numGroups >= 2) { - // Multiple groups: we need to update the per-Gaussian group indices. - for (const [index, sortedIndex] of sortedIndices.entries()) { - sortedGroupIndices[index] = groupIndices[sortedIndex]; - } - self.postMessage({ - sortedIndices: sortedIndices, - sortedGroupIndices: sortedGroupIndices, - }); - } else { - self.postMessage({ - sortedIndices: sortedIndices, - }); - } + // Important: we clone the output so we can transfer the buffer to the main + // thread. Compared to relying on postMessage for copying, this reduces + // backlog artifacts. + const sortedIndices = ( + sorter.sort(Tz_camera_groups) as Uint32Array + ).slice(); + + // @ts-ignore + self.postMessage({ sortedIndices: sortedIndices }, [sortedIndices.buffer]); setTimeout(() => { sortRunning = false; @@ -75,8 +61,6 @@ export type SorterWorkerIncoming = data.setBuffer, data.setGroupIndices, ); - groupIndices = data.setGroupIndices; - sortedGroupIndices = groupIndices.slice(); } else if ("setTz_camera_groups" in data) { // Update object transforms. Tz_camera_groups = data.setTz_camera_groups; diff --git a/src/viser/client/src/Splatting/WasmSorter/Sorter.wasm b/src/viser/client/src/Splatting/WasmSorter/Sorter.wasm index e7d99ccb5bad1502ced6e1c7431163b14eee639c..e190e2d40759b3216cb66a4b9258b443ee8d3457 100755 GIT binary patch delta 3486 zcmb7H+lyUS8DF=3JA0owd#*V%=S*kqHA75H(k3&>$xPasvoe=tlEw-6pjI(LArI%E zcG8(3SllCsN-G%Htq8VWKxv^7TF!$$7%1X{4~5h}K#Xp98`3q$4m zN)*@lF?*RWbH-Pg=BsRqud&ZWM>(@Ro;_)o7SFcM9RJj%*CAtGrdQWd!xm@ z_^ZYJnOoS})_>%W6z8%h@q96Rk)P}xGd<25#b0OtP|@~n)`>lqj`p5E{2j0No4I#* zFFC#=@*hZE$u-yPhzO0KMH$((V*9FUYdEiREj71d%doVmS;x~dZV6S3MH>XLGvjql z2m#}3-Z9GV;HTDA+kS+f)Ix(UEZ;VfR``q78(P>8K?^n-CKoUxN|}wSz^m9wAIgGQ z2inIw6qeTgv}w5ZJDCp(adooDx!|o;8D2uG$V_v+%HUU^!%o+x^ucY~%>pfH2|VU!TPZveX0Xb#KkI$5nN@h%42XV zvaetUi$qphV3pb^#hi_?t7=`4%i62i;*v$G7&8ZifB zI1c--FpZ2%!;I8N8xWHzhiwZZAq@YJ5*l{W0AUI`M(J2&3|h3F`6AcB3Uj~r>{EO1 zjNRB&IpTaR;jv{g;`dh}sh}JwfS;uSc7dIRyFyD~0YRk;?$|wuv_wMtaKpYiQ97Zr z(rd#@S9$7FnmD=YJXWG@Uzj<;a2GaJ`}6kp46>C;@$^ zfi?R}QlUDfp`j@XRv8YLd5U0JX^^?x6rVfw-I3lQIi9O;xrDyHqr1j0{&wn{HBTSf zjT0Dqsd##R?1ZOsf#c9Px{!{MG*@@80nQPjSGDJWk~WJs=CiMR9TPl=b*0C;CL*jW zzz6@XG<39w^$4CuA45HYHb8wm*5i7V0Bt7rEEt??C#!)%T%ByVXfJ$vEYMX#@hDtT z0~ANaw-*jY$XXT>WQ(6I%%+h}2B4fEG)@$MUYKkR(*a0bBLGT0K>!?sKV!gOBK81C zfwSF{jkQnW&>9UF-CL^?s}vwj-O~G_pgHAKcdedHduiw+V7^V}$Pg$nK3+ zR&E+^waVnRJ?jseLD3_@imNsQLCdidTJ{8LU8;vDvx)JLbhr-(;d4@auhkyG7Anf_ z76@taTx+R#YjK=Nl+OV-NZOH(FF8_jfOHb1nc}x+ubdApU&EnHWs?c+3KnDyrvYwC z&kz#iK_jYpqun^=%hFC}O}NN^kLxrRf`< zp6oY-snEiVq=lowpj=#PxV3T#bxT+1-{pq(@HN$MNC;pN#-oKp4ht*gE|+rNV1<+s z#W&8K$whxSn}lh~ET;>IdM+)$sfR!7YNl{+BdDGdZfIr*|HNDT55V$?0r^N=9dkO+myfopu8yMD%jn|dx?OwwLK z?Eo%eo@#|cX#$Ja(G_fJIoKW{fxMI$><49~2M4uv<=P}rUeD<64ioa3JE~EeUPMlY z?R^>rHnS1J4xFN501yy15j@?Dk8?n$O2NV%7o!&Rk(b>3rlA7;zeGF|fo{5*0TvVJ zk+Ms}qfOjs=fVGBKgRfUL^Lx@C)C*Q?4Uksnj|n2wZyCVWYisfi`PyH4(hCL&yX8K zNWc&(uuk!k(hI5{%Aa9`7L96)5~!IYSah(?9E)G@c#_*WbPjImMxlb)f$=js|6=PLDBDA0bK=&?nA z)MQbtsc`Cu;NXlaOuMN2G_6VC6u=Bw#;EG0Fa}o;C`x@-Wy=tt=mEI9q^7 zk}-^OG>yGF36AGiX29{<$}FDuR>mh9hOJ^CAV6T=REXFvVxan9)v11ej5K92s2f} delta 3397 zcmZu!TWnlM8J?Ll=eFmp_xKvy>&?y?N5LU(8pp}nPMc;Yz9epILU}AIMO=b7D=rRk zgjCS+15XH{N&~zgM34$pAwm)R!UGW%@PNbvdE|kLA|YN7QmKU62bS{vbJjK_?0V14 znVJ7G|Ns3L-}@nd?=^nE&6mzCN$&fcF@ZnP;4JVxdJ{2|vGjb27kok%Mv)8xA?*wN zwAtqCobe5o^G!C(x7ell6laDf^G{x0UFlpnzxaWz#n$Q5*IHMfF`KRKm5tVV>#5eo z*3&1~PS*Ng=S%$ud_Cy@C4Q0rEP9&rro9yZ*glgS=ed0;`I7y9`{VtG$?wPQ`|aCU zx~+c0&)TQ*7xDZ~{%!tbe^K{2pR;ex|0>DNLsld{%f?xM@q>Tn{hyvF{eFFMUzES| z#YoA&GS5ub)A9sUd_hR_PL##k&!Su?-^ABtSNR!Flr*d(A|=c>nwUpL7Ur3f3nJ3K zJK`nSi?{QJC{?!Mtis+lpFEs1c4HR@v!A3qF)Yb_RW>%UAc#4un6oQvMKC3mFG~Tr ziV=U02B}5Qh`k+9Fxpn6ow%h%)>NYKl}I{5f;Gqloo71i=>!xyQhuQWvyYG5mR<8c zKA!SaSoXAM_@1t*0N>hnQ+eiJ*bJ{XwFfy~MRUC_Fcr75<9VE@SVfqFVI{tnH6fr@ zXDS9TOm&jygFDIL2y}XV+5Z^Yu&?A)&yE>-p z1wFVSQWyfma?b-tCq_p-Raa4O0ALOh*vvDsfTiNAbw)~#Y}I*!(*em3AO(YfRXR8- zR7j_i5|&KKr%csgJPZMChq!u}%_gZxGR2Z?9t_E#KD^cuDR~D(=8dTe1`-MqZZxaF z!mR{k z8M3GZjm#3Ggfao7&Q&reU@D<9 zTV6O-HuD;;EnpnpGai9OI&ZMuJb*1~6c@L2Sg25C#Q-^LkVpgrjGo=pUO~qA-(~gx zm(M=F49-Of(Nh$cQz~!~2l3vn4mRuTNcy8<1d;&g#wZtz4e1QW3Ox8cQqiayA!DbS z`~dZ#AasuCIruBrGYa29UHU=D?2kH&BLwB6WA?4iM`$VfuchvaTLP>OnAtlkpP!En z--5+v^O+REh5+`#NCyyM-(6=!f?(v2|*A5 zAT1z>!cha*V+fWZ6X;LHW}y@eH=*Q#EaHk%p#ofAftgz#JE9mayTl_5I0(eMq)>qu zrf?2njWDNF5;?;~Y8-~cCP)2H=|M8(szH@rWXCmMrTl@GAfKXMAf=#T&ZiBNU4=zc z^KYv0ySc<_4Oen^8B?Q{Uut~ z(siin=+^@C_OM?&s;#Id2rmfAH1%tg)l^*dYY^^3xL=_8M8^a)!2JO=ztEKiW}BB= zD@2S?#U-iJkr_lZ*>S$cYG+nt$+LPVSMJt@cX0mJVcRz|=J~i#oXXf5yvSLTgwB0c6 z{%6gA+WgO~U%~V4`WNwRZp`A@**Jser#IG*&1nJ)5SzN5i(2j1XRR* Ooh|&6(jUG0-Q+)67*C-9 diff --git a/src/viser/client/src/Splatting/WasmSorter/sorter.cpp b/src/viser/client/src/Splatting/WasmSorter/sorter.cpp index 72ed2a6c..ea9675f9 100644 --- a/src/viser/client/src/Splatting/WasmSorter/sorter.cpp +++ b/src/viser/client/src/Splatting/WasmSorter/sorter.cpp @@ -1,11 +1,11 @@ +#include +#include +#include + #include #include #include #include -#include - -#include -#include /** SIMD dot product between two 4D vectors. */ __attribute__((always_inline)) inline float @@ -41,6 +41,7 @@ __attribute__((always_inline)) inline int32_t max_i32x4(v128_t vector) { class Sorter { std::vector centers_homog; // Centers as homogeneous coordinates. std::vector group_indices; + std::vector sorted_indices; public: Sorter( @@ -51,7 +52,7 @@ class Sorter { const float *floatBuffer = reinterpret_cast(bufferVec.data()); const int32_t num_gaussians = bufferVec.size() / 8; - + sorted_indices.resize(num_gaussians); centers_homog.resize(num_gaussians); for (int32_t i = 0; i < num_gaussians; i++) { centers_homog[i] = wasm_f32x4_make( @@ -173,7 +174,6 @@ class Sorter { } // Update and return sorted indices. - std::vector sorted_indices(num_gaussians); for (int32_t i = 0; i < num_gaussians; i++) sorted_indices[starts0[((int32_t *)&gaussian_zs[0])[i]]++] = i; return emscripten::val(emscripten::typed_memory_view(