chore(viz): minor improvements

This commit is contained in:
Steven Palma
2026-07-01 13:16:31 +02:00
parent 15678219c6
commit a33b165b12
2 changed files with 37 additions and 25 deletions
+4 -4
View File
@@ -85,7 +85,7 @@ import torch.utils.data
import tqdm
from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD, SUCCESS
from lerobot.utils.utils import init_logging
@@ -138,7 +138,7 @@ def build_blueprint_from_dataset(dataset: LeRobotDataset):
names = get_feature_names(dataset, key)
styling = rr.SeriesLines(names=names)
views.append(rrb.TimeSeriesView(origin=origin, name=origin, overrides={origin: styling}))
for key in (DONE, REWARD, "next.success"):
for key in (DONE, REWARD, SUCCESS):
if key in dataset.features:
views.append(rrb.TimeSeriesView(origin=key, name=key))
@@ -270,8 +270,8 @@ def visualize_dataset(
if REWARD in batch:
rr.log(REWARD, rr.Scalars(batch[REWARD][i].item()))
if "next.success" in batch:
rr.log("next.success", rr.Scalars(batch["next.success"][i].item()))
if SUCCESS in batch:
rr.log(SUCCESS, rr.Scalars(batch[SUCCESS][i].item()))
# save .rrd locally
if mode == "local" and save:
+33 -21
View File
@@ -32,18 +32,13 @@ from .constants import (
OBS_STR,
REWARD,
SUCCESS,
TRUNCATED,
)
from .import_utils import require_package
# Visualization backends selectable at runtime via a display-mode string (e.g. a --display_mode flag).
VISUALIZATION_MODES = ("rerun", "foxglove")
# Module-level Foxglove state. A single WebSocket server is shared for the
# process lifetime, and image channels are cached by topic (the Foxglove SDK
# requires reusing one channel per topic).
_foxglove_server = None
_foxglove_channels: dict = {}
# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}``
# pairs rather than one field per feature, so the same schema fits any robot regardless of which
# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name
@@ -118,20 +113,23 @@ def init_foxglove(host: str = "127.0.0.1", port: int | None = 8765) -> None:
require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
import foxglove
global _foxglove_server
if _foxglove_server is not None:
# Live-stream state lives as attributes on ``log_foxglove_data``:
# ``.server`` is the shared WebSocket server and
# ``.channels`` caches one Foxglove channel per topic
if getattr(log_foxglove_data, "server", None) is not None:
return
_foxglove_server = foxglove.start_server(host=host, port=port or 8765)
log_foxglove_data.server = foxglove.start_server(host=host, port=port or 8765)
log_foxglove_data.channels = {}
def shutdown_foxglove() -> None:
"""Stops the Foxglove WebSocket server and clears cached channels."""
global _foxglove_server
if _foxglove_server is not None:
_foxglove_server.stop()
_foxglove_server = None
_foxglove_channels.clear()
server = getattr(log_foxglove_data, "server", None)
if server is not None:
server.stop()
log_foxglove_data.server = None
log_foxglove_data.channels = {}
def _is_scalar(x):
@@ -223,7 +221,7 @@ def _log_foxglove_scalars(
import foxglove
if channels is None:
channels = _foxglove_channels
channels = log_foxglove_data.channels
channel = channels.get(topic)
if channel is None:
channel = channels[topic] = foxglove.Channel(topic, schema=_SCALARS_SCHEMA, message_encoding="json")
@@ -266,7 +264,7 @@ def _log_foxglove_image(
from foxglove.messages import CompressedImage, RawImage, Timestamp
if channels is None:
channels = _foxglove_channels
channels = log_foxglove_data.channels
time_ns = time.time_ns() if log_time is None else log_time
timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000)
log_kwargs = {} if log_time is None else {"log_time": log_time}
@@ -424,7 +422,7 @@ def log_foxglove_data(
require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
if _foxglove_server is None:
if getattr(log_foxglove_data, "server", None) is None:
raise RuntimeError("init_foxglove() must be called before log_foxglove_data().")
now = time.time_ns()
@@ -603,7 +601,12 @@ def serve_foxglove_dataset_playback(
log_time=log_time,
)
episode_scalars = {}
for feat, label in ((DONE, "done"), (REWARD, "reward"), (SUCCESS, "success")):
for feat, label in (
(DONE, "done"),
(TRUNCATED, "truncated"),
(REWARD, "reward"),
(SUCCESS, "success"),
):
v = sample.get(feat)
if v is not None:
episode_scalars[label] = float(v)
@@ -660,6 +663,10 @@ def serve_foxglove_dataset_playback(
)
def playback_loop() -> None:
# Cap how far the cursor may advance in a single tick. A slow frame decode (or any stall)
# would otherwise make ``dt`` huge and produce one enormous catch-up batch; clamping it makes
# playback trail wall-clock under a slow decoder while each tick emits a bounded frame range.
max_tick_dt_s = 0.25
prev = time.monotonic()
while not stop_event.is_set():
time.sleep(1.0 / 60.0)
@@ -667,7 +674,7 @@ def serve_foxglove_dataset_playback(
speed = 1.0
with lock:
now = time.monotonic()
dt = now - prev
dt = min(now - prev, max_tick_dt_s)
prev = now
# A queued seek is always serviced, even while paused, so scrubbing updates the view.
work = []
@@ -684,14 +691,19 @@ def serve_foxglove_dataset_playback(
target = index_at(cursor)
state["cursor"] = cursor
work.extend(range(start_idx, target + 1))
state["last_idx"] = max(state["last_idx"], target)
# cursor only grows while playing (seeks reset last_idx in the listener), so
# target >= last_idx here; a plain assignment is correct and clearer than max().
state["last_idx"] = target
if ended:
state["status"] = PlaybackStatus.Ended
if not work:
continue
cursor, speed = state["cursor"], state["speed"]
# Emit outside the lock; this is the only thread that calls emit_frame.
# Emit outside the lock; this is the only thread that calls emit_frame. Re-check
# stop_event between frames so shutdown stays responsive even mid-batch.
for i in work:
if stop_event.is_set():
break
emit_frame(i)
server.broadcast_time(cursor)
if ended: