Skip to content

Commit

Permalink
New share URL API (nerfstudio-project#139)
Browse files Browse the repository at this point in the history
* Add API for accessing host, port, share URL after instantiation

* Deprecate old share URL api

* get_share_url() => request_share_url()

* Bump version

* Update Record3D example
  • Loading branch information
brentyi authored Nov 30, 2023
1 parent da51044 commit 8b3fa9d
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 36 deletions.
4 changes: 3 additions & 1 deletion examples/07_record3d_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def main(
max_frames: int = 100,
share: bool = False,
) -> None:
server = viser.ViserServer(share=share)
server = viser.ViserServer()
if share:
server.request_share_url()

print("Loading frames!")
loader = viser.extras.Record3dLoader(data_path)
Expand Down
5 changes: 4 additions & 1 deletion examples/08_smplx_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def main(
ext: Literal["npz", "pkl"] = "npz",
share: bool = False,
) -> None:
server = viser.ViserServer(share=share)
server = viser.ViserServer()
if share:
server.request_share_url()

server.configure_theme(control_layout="collapsible")
model = smplx.create(
model_path=str(model_path),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "viser"
version = "0.1.12"
version = "0.1.13"
description = "3D visualization + Python"
readme = "README.md"
license = { text="MIT" }
Expand Down
10 changes: 9 additions & 1 deletion src/viser/_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
import time
from multiprocessing.managers import DictProxy
from typing import Callable, Optional
from typing import Callable, Literal, Optional

import requests

Expand Down Expand Up @@ -46,6 +46,11 @@ def get_url(self) -> Optional[str]:
"""Get tunnel URL. None if not connected (or connection failed)."""
return self._shared_state["url"]

def get_status(
self,
) -> Literal["ready", "connecting", "failed", "connected", "closed"]:
return self._shared_state["status"]

def close(self) -> None:
"""Close the tunnel."""
if self._process is not None:
Expand Down Expand Up @@ -104,6 +109,9 @@ async def _make_tunnel(local_port: int, shared_state: DictProxy) -> None:
]
)

shared_state["url"] = None
shared_state["status"] = "closed"


async def _simple_proxy(
local_host: str,
Expand Down
128 changes: 96 additions & 32 deletions src/viser/_viser.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,19 @@ class ViserServer(MessageApi, GuiApi):
Args:
host: Host to bind server to.
port: Port to bind server to.
share: Experimental. If set to `True`, create and print a public, shareable URL
for this instance of viser.
"""

world_axes: FrameHandle
"""Handle for manipulating the world frame axes (/WorldAxes), which is instantiated
and then hidden by default."""

def __init__(self, host: str = "0.0.0.0", port: int = 8080, share: bool = False):
# Hide deprecated arguments from docstring and type checkers.
def __init__(self, host: str = "0.0.0.0", port: int = 8080):
...

def _actual_init(
self, host: str = "0.0.0.0", port: int = 8080, **_deprecated_kwargs
):
server = infra.Server(
host=host,
port=port,
Expand Down Expand Up @@ -414,26 +418,15 @@ def _(conn: infra.ClientConnection) -> None:
)
table.add_row("HTTP", http_url)
table.add_row("Websocket", ws_url)
rich.print(Panel(table, title="[bold]viser[/bold]", expand=False))

# Create share tunnel if requested.
if not share:
self._share_tunnel = None
rich.print(Panel(table, title="[bold]viser[/bold]", expand=False))
else:
rich.print(
"[bold](viser)[/bold] Share URL requested! (expires in 24 hours)"
)
self._share_tunnel = _ViserTunnel(port)
self._share_tunnel: Optional[_ViserTunnel] = None

@self._share_tunnel.on_connect
def _() -> None:
assert self._share_tunnel is not None
share_url = self._share_tunnel.get_url()
if share_url is None:
rich.print("[bold](viser)[/bold] Could not generate share URL")
else:
table.add_row("Share URL", share_url)
rich.print(Panel(table, title="[bold]viser[/bold]", expand=False))
# Create share tunnel if requested.
# This is deprecated: we should use get_share_url() instead.
share = _deprecated_kwargs.get("share", False)
if share:
self.request_share_url()

self.reset_scene()
self.world_axes = FrameHandle(
Expand All @@ -446,25 +439,80 @@ def _() -> None:
)
self.world_axes.visible = False

def get_host(self) -> str:
"""Returns the host address of the Viser server.
Returns:
Host address as string.
"""
return self._server._host

def get_port(self) -> int:
"""Returns the port of the Viser server. This could be different from the
originally requested one.
Returns:
Port as integer.
"""
return self._server._port

def request_share_url(self, verbose: bool = True) -> Optional[str]:
"""Request a share URL for the Viser server, which allows for public access.
On the first call, will block until a connecting with the share URL server is
established. Afterwards, the URL will be returned directly.
This is an experimental feature that relies on an external server; it shouldn't
be relied on for critical applications.
Returns:
Share URL as string, or None if connection fails or is closed.
"""

if self._share_tunnel is not None:
# Tunnel already exists.
while self._share_tunnel.get_status() in ("ready", "connecting"):
time.sleep(0.05)
return self._share_tunnel.get_url()
else:
# Create a new tunnel!.
if verbose:
rich.print(
"[bold](viser)[/bold] Share URL requested! (expires in 24 hours)"
)

connect_event = threading.Event()

self._share_tunnel = _ViserTunnel(self._server._port)

@self._share_tunnel.on_connect
def _() -> None:
assert self._share_tunnel is not None
if verbose:
share_url = self._share_tunnel.get_url()
if share_url is None:
rich.print("[bold](viser)[/bold] Could not generate share URL")
else:
rich.print(
f"[bold](viser)[/bold] Generated share URL: {share_url}"
)
connect_event.set()

connect_event.wait()
return self._share_tunnel.get_url()

def stop(self) -> None:
"""Stop the Viser server and associated threads and tunnels."""
self._server.stop()
if self._share_tunnel is not None:
self._share_tunnel.close()

@override
def _get_api(self) -> MessageApi:
"""Message API to use."""
return self

@override
def _queue_unsafe(self, message: _messages.Message) -> None:
"""Define how the message API should send messages."""
self._server.broadcast(message)

def get_clients(self) -> Dict[int, ClientHandle]:
"""Creates and returns a copy of the mapping from connected client IDs to
handles."""
handles.
Returns:
Dictionary of clients.
"""
with self._state.client_lock:
return self._state.connected_clients.copy()

Expand Down Expand Up @@ -502,6 +550,9 @@ def atomic(self) -> Generator[None, None, None]:
This can be helpful for things like animations, or when we want position and
orientation updates to happen synchronously.
Returns:
Context manager.
"""
# Acquire the global atomic lock.
# If called multiple times in the same thread, we ignore inner calls.
Expand Down Expand Up @@ -530,3 +581,16 @@ def flush(self) -> None:
"""Flush the outgoing message buffer. Any buffered messages will immediately be
sent. (by default they are windowed)"""
self._server.flush()

@override
def _get_api(self) -> MessageApi:
"""Message API to use."""
return self

@override
def _queue_unsafe(self, message: _messages.Message) -> None:
"""Define how the message API should send messages."""
self._server.broadcast(message)


ViserServer.__init__ = ViserServer._actual_init # type: ignore

0 comments on commit 8b3fa9d

Please sign in to comment.