mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 17:17:01 +00:00
chore(viz): minor improvements
This commit is contained in:
@@ -85,7 +85,7 @@ import torch.utils.data
|
||||
import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD, SUCCESS
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ def build_blueprint_from_dataset(dataset: LeRobotDataset):
|
||||
names = get_feature_names(dataset, key)
|
||||
styling = rr.SeriesLines(names=names)
|
||||
views.append(rrb.TimeSeriesView(origin=origin, name=origin, overrides={origin: styling}))
|
||||
for key in (DONE, REWARD, "next.success"):
|
||||
for key in (DONE, REWARD, SUCCESS):
|
||||
if key in dataset.features:
|
||||
views.append(rrb.TimeSeriesView(origin=key, name=key))
|
||||
|
||||
@@ -270,8 +270,8 @@ def visualize_dataset(
|
||||
if REWARD in batch:
|
||||
rr.log(REWARD, rr.Scalars(batch[REWARD][i].item()))
|
||||
|
||||
if "next.success" in batch:
|
||||
rr.log("next.success", rr.Scalars(batch["next.success"][i].item()))
|
||||
if SUCCESS in batch:
|
||||
rr.log(SUCCESS, rr.Scalars(batch[SUCCESS][i].item()))
|
||||
|
||||
# save .rrd locally
|
||||
if mode == "local" and save:
|
||||
|
||||
@@ -32,18 +32,13 @@ from .constants import (
|
||||
OBS_STR,
|
||||
REWARD,
|
||||
SUCCESS,
|
||||
TRUNCATED,
|
||||
)
|
||||
from .import_utils import require_package
|
||||
|
||||
# Visualization backends selectable at runtime via a display-mode string (e.g. a --display_mode flag).
|
||||
VISUALIZATION_MODES = ("rerun", "foxglove")
|
||||
|
||||
# Module-level Foxglove state. A single WebSocket server is shared for the
|
||||
# process lifetime, and image channels are cached by topic (the Foxglove SDK
|
||||
# requires reusing one channel per topic).
|
||||
_foxglove_server = None
|
||||
_foxglove_channels: dict = {}
|
||||
|
||||
# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}``
|
||||
# pairs rather than one field per feature, so the same schema fits any robot regardless of which
|
||||
# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name
|
||||
@@ -118,20 +113,23 @@ def init_foxglove(host: str = "127.0.0.1", port: int | None = 8765) -> None:
|
||||
require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
|
||||
import foxglove
|
||||
|
||||
global _foxglove_server
|
||||
if _foxglove_server is not None:
|
||||
# Live-stream state lives as attributes on ``log_foxglove_data``:
|
||||
# ``.server`` is the shared WebSocket server and
|
||||
# ``.channels`` caches one Foxglove channel per topic
|
||||
if getattr(log_foxglove_data, "server", None) is not None:
|
||||
return
|
||||
_foxglove_server = foxglove.start_server(host=host, port=port or 8765)
|
||||
log_foxglove_data.server = foxglove.start_server(host=host, port=port or 8765)
|
||||
log_foxglove_data.channels = {}
|
||||
|
||||
|
||||
def shutdown_foxglove() -> None:
|
||||
"""Stops the Foxglove WebSocket server and clears cached channels."""
|
||||
|
||||
global _foxglove_server
|
||||
if _foxglove_server is not None:
|
||||
_foxglove_server.stop()
|
||||
_foxglove_server = None
|
||||
_foxglove_channels.clear()
|
||||
server = getattr(log_foxglove_data, "server", None)
|
||||
if server is not None:
|
||||
server.stop()
|
||||
log_foxglove_data.server = None
|
||||
log_foxglove_data.channels = {}
|
||||
|
||||
|
||||
def _is_scalar(x):
|
||||
@@ -223,7 +221,7 @@ def _log_foxglove_scalars(
|
||||
import foxglove
|
||||
|
||||
if channels is None:
|
||||
channels = _foxglove_channels
|
||||
channels = log_foxglove_data.channels
|
||||
channel = channels.get(topic)
|
||||
if channel is None:
|
||||
channel = channels[topic] = foxglove.Channel(topic, schema=_SCALARS_SCHEMA, message_encoding="json")
|
||||
@@ -266,7 +264,7 @@ def _log_foxglove_image(
|
||||
from foxglove.messages import CompressedImage, RawImage, Timestamp
|
||||
|
||||
if channels is None:
|
||||
channels = _foxglove_channels
|
||||
channels = log_foxglove_data.channels
|
||||
time_ns = time.time_ns() if log_time is None else log_time
|
||||
timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000)
|
||||
log_kwargs = {} if log_time is None else {"log_time": log_time}
|
||||
@@ -424,7 +422,7 @@ def log_foxglove_data(
|
||||
|
||||
require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
|
||||
|
||||
if _foxglove_server is None:
|
||||
if getattr(log_foxglove_data, "server", None) is None:
|
||||
raise RuntimeError("init_foxglove() must be called before log_foxglove_data().")
|
||||
|
||||
now = time.time_ns()
|
||||
@@ -603,7 +601,12 @@ def serve_foxglove_dataset_playback(
|
||||
log_time=log_time,
|
||||
)
|
||||
episode_scalars = {}
|
||||
for feat, label in ((DONE, "done"), (REWARD, "reward"), (SUCCESS, "success")):
|
||||
for feat, label in (
|
||||
(DONE, "done"),
|
||||
(TRUNCATED, "truncated"),
|
||||
(REWARD, "reward"),
|
||||
(SUCCESS, "success"),
|
||||
):
|
||||
v = sample.get(feat)
|
||||
if v is not None:
|
||||
episode_scalars[label] = float(v)
|
||||
@@ -660,6 +663,10 @@ def serve_foxglove_dataset_playback(
|
||||
)
|
||||
|
||||
def playback_loop() -> None:
|
||||
# Cap how far the cursor may advance in a single tick. A slow frame decode (or any stall)
|
||||
# would otherwise make ``dt`` huge and produce one enormous catch-up batch; clamping it makes
|
||||
# playback trail wall-clock under a slow decoder while each tick emits a bounded frame range.
|
||||
max_tick_dt_s = 0.25
|
||||
prev = time.monotonic()
|
||||
while not stop_event.is_set():
|
||||
time.sleep(1.0 / 60.0)
|
||||
@@ -667,7 +674,7 @@ def serve_foxglove_dataset_playback(
|
||||
speed = 1.0
|
||||
with lock:
|
||||
now = time.monotonic()
|
||||
dt = now - prev
|
||||
dt = min(now - prev, max_tick_dt_s)
|
||||
prev = now
|
||||
# A queued seek is always serviced, even while paused, so scrubbing updates the view.
|
||||
work = []
|
||||
@@ -684,14 +691,19 @@ def serve_foxglove_dataset_playback(
|
||||
target = index_at(cursor)
|
||||
state["cursor"] = cursor
|
||||
work.extend(range(start_idx, target + 1))
|
||||
state["last_idx"] = max(state["last_idx"], target)
|
||||
# cursor only grows while playing (seeks reset last_idx in the listener), so
|
||||
# target >= last_idx here; a plain assignment is correct and clearer than max().
|
||||
state["last_idx"] = target
|
||||
if ended:
|
||||
state["status"] = PlaybackStatus.Ended
|
||||
if not work:
|
||||
continue
|
||||
cursor, speed = state["cursor"], state["speed"]
|
||||
# Emit outside the lock; this is the only thread that calls emit_frame.
|
||||
# Emit outside the lock; this is the only thread that calls emit_frame. Re-check
|
||||
# stop_event between frames so shutdown stays responsive even mid-batch.
|
||||
for i in work:
|
||||
if stop_event.is_set():
|
||||
break
|
||||
emit_frame(i)
|
||||
server.broadcast_time(cursor)
|
||||
if ended:
|
||||
|
||||
Reference in New Issue
Block a user