From 77c9983a7dd9c8a905ff05af910705a25ae50309 Mon Sep 17 00:00:00 2001 From: Roman Shtylman Date: Thu, 18 Jun 2026 15:09:30 -0700 Subject: [PATCH] add foxglove option to dataset viz --- src/lerobot/scripts/lerobot_dataset_viz.py | 56 +++- src/lerobot/utils/visualization_utils.py | 302 ++++++++++++++++++--- 2 files changed, 314 insertions(+), 44 deletions(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 22a7208d4..89006fd57 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -59,6 +59,18 @@ distant$ lerobot-dataset-viz \ local$ rerun rerun+http://IP:GRPC_PORT/proxy ``` +- Visualize data in Foxglove with a seekable, scrubbable timeline: +``` +local$ lerobot-dataset-viz \ + --repo-id lerobot/pusht \ + --episode-index 0 \ + --display-mode foxglove + +# then open the Foxglove app and connect to ws://127.0.0.1:8765 +``` +This starts a Foxglove WebSocket server that serves the episode on demand from the on-disk dataset, +so you can play/pause and scrub anywhere in the episode using Foxglove's playback controls. + """ import argparse @@ -150,8 +162,26 @@ def visualize_dataset( save: bool = False, output_dir: Path | None = None, display_compressed_images: bool = False, + display_mode: str = "rerun", + host: str = "127.0.0.1", + port: int = 8765, **kwargs, ) -> Path | None: + if display_mode == "foxglove": + if save: + logging.warning("--save is ignored with --display-mode foxglove (no .rrd file is written).") + from lerobot.utils.visualization_utils import serve_foxglove_dataset_playback + + logging.info("Starting Foxglove server") + serve_foxglove_dataset_playback( + dataset, + episode_index, + host=host, + port=port, + compress_images=display_compressed_images, + ) + return None + if save: assert output_dir is not None, ( "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." @@ -351,7 +381,31 @@ def main(): parser.add_argument( "--display-compressed-images", action="store_true", - help="If set, display compressed images in Rerun instead of uncompressed ones.", + help="If set, display compressed (JPEG) images instead of uncompressed ones.", + ) + + parser.add_argument( + "--display-mode", + type=str, + default="rerun", + choices=["rerun", "foxglove"], + help=( + "Visualization backend. 'rerun' uses the Rerun viewer (--mode/--save/--*-port apply). " + "'foxglove' starts a Foxglove WebSocket server that serves the episode as a seekable, " + "scrubbable timeline; connect the Foxglove app to ws://HOST:PORT (--host/--port)." + ), + ) + parser.add_argument( + "--host", + type=str, + default="127.0.0.1", + help="Host to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.", + ) + parser.add_argument( + "--port", + type=int, + default=8765, + help="Port to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.", ) args = parser.parse_args() diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index baeafa311..d6d477a84 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -20,7 +20,7 @@ import numpy as np from lerobot.types import RobotAction, RobotObservation -from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR +from .constants import ACTION, ACTION_PREFIX, DONE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD from .import_utils import require_package # Module-level Foxglove state. A single WebSocket server is shared for the @@ -171,11 +171,15 @@ _SCALARS_SCHEMA = { } -def _log_foxglove_scalars(topic: str, values: dict[str, float]) -> None: +def _log_foxglove_scalars(topic: str, values: dict[str, float], *, log_time: int | None = None) -> None: """Log scalars on a typed JSON channel using the static :data:`_SCALARS_SCHEMA`. ``values`` is an ordered mapping of feature name to value; it is emitted as a ``scalars`` array of ``{label, value}`` objects. Insertion order is preserved so series stay stable across messages. + + ``log_time`` is the message time in nanoseconds. When ``None`` the server's receive time is used + (correct for live streaming); dataset playback passes the frame's dataset timestamp so the + Foxglove timeline reflects the recorded episode. """ if not values: @@ -188,7 +192,67 @@ def _log_foxglove_scalars(topic: str, values: dict[str, float]) -> None: channel = _foxglove_channels[topic] = foxglove.Channel( topic, schema=_SCALARS_SCHEMA, message_encoding="json" ) - channel.log({"scalars": [{"label": label, "value": value} for label, value in values.items()]}) + msg = {"scalars": [{"label": label, "value": value} for label, value in values.items()]} + if log_time is None: + channel.log(msg) + else: + channel.log(msg, log_time=log_time) + + +def _log_foxglove_image( + topic: str, frame_id: str, arr: np.ndarray, *, compress_images: bool, time_ns: int +) -> None: + """Log an image on a cached per-topic channel, stamped at ``time_ns`` (nanoseconds). + + ``arr`` may be HWC or CHW; CHW is transposed to HWC. ``time_ns`` sets both the message header + timestamp and the channel ``log_time`` so the message lands at the right point on the Foxglove + timeline (matching wall-clock for live streaming, or the dataset timestamp during playback). + """ + + from foxglove.channels import CompressedImageChannel, RawImageChannel + from foxglove.messages import CompressedImage, RawImage, Timestamp + + timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000) + + # Convert CHW -> HWC when needed (mirrors log_rerun_data). + if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): + arr = np.transpose(arr, (1, 2, 0)) + height, width = arr.shape[0], arr.shape[1] + channels = 1 if arr.ndim == 2 else arr.shape[2] + + if compress_images and channels == 3: + import cv2 + + # Camera frames are RGB; cv2.imencode assumes BGR, so swap to keep colors correct. + _, buf = cv2.imencode(".jpg", cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)) + channel = _foxglove_channels.get(topic) + if channel is None: + channel = _foxglove_channels[topic] = CompressedImageChannel(topic=topic) + channel.log( + CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg"), + log_time=time_ns, + ) + return + + encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(channels) + if encoding is None: + return + arr = np.ascontiguousarray(arr, dtype=np.uint8) + channel = _foxglove_channels.get(topic) + if channel is None: + channel = _foxglove_channels[topic] = RawImageChannel(topic=topic) + channel.log( + RawImage( + timestamp=timestamp, + frame_id=frame_id, + width=width, + height=height, + encoding=encoding, + step=width * channels, + data=arr.tobytes(), + ), + log_time=time_ns, + ) def log_rerun_data( @@ -295,53 +359,14 @@ def log_foxglove_data( """ require_package("foxglove-sdk", extra="foxglove", import_name="foxglove") - from foxglove.channels import CompressedImageChannel, RawImageChannel - from foxglove.messages import CompressedImage, RawImage, Timestamp if _foxglove_server is None: raise RuntimeError("init_foxglove() must be called before log_foxglove_data().") now = time.time_ns() - timestamp = Timestamp(sec=now // 1_000_000_000, nsec=now % 1_000_000_000) def log_image(topic: str, frame_id: str, arr: np.ndarray) -> None: - # Convert CHW -> HWC when needed (mirrors log_rerun_data). - if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): - arr = np.transpose(arr, (1, 2, 0)) - height, width = arr.shape[0], arr.shape[1] - channels = 1 if arr.ndim == 2 else arr.shape[2] - - if compress_images and channels == 3: - import cv2 - - # Camera frames are RGB; cv2.imencode assumes BGR, so swap to keep colors correct. - _, buf = cv2.imencode(".jpg", cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)) - channel = _foxglove_channels.get(topic) - if channel is None: - channel = _foxglove_channels[topic] = CompressedImageChannel(topic=topic) - channel.log( - CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg") - ) - return - - encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(channels) - if encoding is None: - return - arr = np.ascontiguousarray(arr, dtype=np.uint8) - channel = _foxglove_channels.get(topic) - if channel is None: - channel = _foxglove_channels[topic] = RawImageChannel(topic=topic) - channel.log( - RawImage( - timestamp=timestamp, - frame_id=frame_id, - width=width, - height=height, - encoding=encoding, - step=width * channels, - data=arr.tobytes(), - ) - ) + _log_foxglove_image(topic, frame_id, arr, compress_images=compress_images, time_ns=now) if observation: obs_scalars: dict[str, float] = {} @@ -374,6 +399,197 @@ def log_foxglove_data( _log_foxglove_scalars(f"/{ACTION}/state", action_scalars) +# ── Dataset playback over a Foxglove WebSocket server ───────────────────── +# A LeRobotDataset is random-access on disk, so rather than fire-and-forget a forward stream we +# advertise a seekable timeline and serve frames on demand for whatever time the user scrubs/plays +# to in the Foxglove app. This relies on the SDK's PlaybackControl capability. + +_SUCCESS = "next.success" + + +def _frame_to_scalars(sample: dict, key: str) -> dict[str, float]: + """Flatten a frame's vector/scalar feature ``key`` into ``{label: value}`` entries. + + Vectors are expanded to ```` labels (one series per dimension); a scalar becomes a single + entry. Missing or ``None`` features yield an empty mapping. + """ + + v = sample.get(key) + if v is None: + return {} + arr = v.numpy() if hasattr(v, "numpy") else np.asarray(v) + if arr.ndim == 0: + return {"0": float(arr)} + return {str(i): float(x) for i, x in enumerate(arr.flatten())} + + +def serve_foxglove_dataset_playback( + dataset, + episode_index: int, + *, + host: str = "127.0.0.1", + port: int = 8765, + compress_images: bool = False, +) -> None: + """Serve a single dataset episode to Foxglove as a seekable, scrubbable timeline. + + Starts a Foxglove WebSocket server advertising the ``PlaybackControl`` capability over the + episode's time range. The Foxglove app drives play/pause/seek/speed; a background thread and a + ``ServerListener`` read frames from the on-disk ``dataset`` on demand and log them stamped at + their dataset timestamps, so the user can scrub anywhere in the episode. Blocks until interrupted. + + Args: + dataset: A ``LeRobotDataset`` loaded for the single episode to visualize. + episode_index: Index of the episode being visualized (used only for the session name). + host: Host interface to bind the WebSocket server to. + port: Port to bind the WebSocket server to. + compress_images: Whether to JPEG-compress camera frames before logging. + """ + + require_package("foxglove-sdk", extra="foxglove", import_name="foxglove") + import bisect + import threading + + import foxglove + from foxglove.websocket import ( + Capability, + PlaybackCommand, + PlaybackControlRequest, + PlaybackState, + PlaybackStatus, + ServerListener, + ) + + # Per-frame timestamps in nanoseconds (read straight from the table, no video decode). + times_ns = [int(round(float(t) * 1e9)) for t in dataset.hf_dataset["timestamp"]] + n_frames = len(times_ns) + if n_frames == 0: + raise ValueError("Cannot visualize an empty episode.") + first_ns, last_ns = times_ns[0], times_ns[-1] + camera_keys = list(dataset.meta.camera_keys) + + def topic_for(key: str) -> str: + name = key[len(OBS_PREFIX) :] if str(key).startswith(OBS_PREFIX) else str(key) + return f"/{OBS_STR}/images/{_foxglove_safe_name(name)}" + + def emit_frame(i: int) -> None: + """Log every channel for frame ``i`` stamped at its dataset timestamp.""" + sample = dataset[i] + log_time = times_ns[i] + for key in camera_keys: + arr = sample.get(key) + if arr is None: + continue + arr = arr.numpy() if hasattr(arr, "numpy") else np.asarray(arr) + if np.issubdtype(arr.dtype, np.floating): + arr = (arr * 255.0).clip(0, 255).astype(np.uint8) + _log_foxglove_image(topic_for(key), key, arr, compress_images=compress_images, time_ns=log_time) + _log_foxglove_scalars(f"/{OBS_STR}/state", _frame_to_scalars(sample, OBS_STATE), log_time=log_time) + _log_foxglove_scalars(f"/{ACTION}/state", _frame_to_scalars(sample, ACTION), log_time=log_time) + episode_scalars = {} + for feat, label in ((DONE, "done"), (REWARD, "reward"), (_SUCCESS, "success")): + v = sample.get(feat) + if v is not None: + episode_scalars[label] = float(v) + _log_foxglove_scalars("/episode/state", episode_scalars, log_time=log_time) + + lock = threading.Lock() + stop_event = threading.Event() + server_holder: dict = {} + # Shared playback state, guarded by ``lock``. + state = {"status": PlaybackStatus.Paused, "cursor": first_ns, "speed": 1.0, "last_idx": -1} + + def index_at(t_ns: int) -> int: + return max(0, min(n_frames - 1, bisect.bisect_right(times_ns, t_ns) - 1)) + + class _PlaybackListener(ServerListener): + def on_playback_control_request(self, req: PlaybackControlRequest): + emit_idx = None + with lock: + did_seek = False + if req.seek_time is not None: + cursor = max(first_ns, min(last_ns, req.seek_time)) + state["cursor"] = cursor + emit_idx = state["last_idx"] = index_at(cursor) + did_seek = True + if req.playback_speed and req.playback_speed > 0: + state["speed"] = req.playback_speed + if req.playback_command == PlaybackCommand.Play: + # Restarting from the end replays from the beginning. + if state["cursor"] >= last_ns: + state["cursor"] = first_ns + emit_idx = state["last_idx"] = 0 + did_seek = True + state["status"] = PlaybackStatus.Playing + elif req.playback_command == PlaybackCommand.Pause: + state["status"] = PlaybackStatus.Paused + status, cursor, speed = state["status"], state["cursor"], state["speed"] + request_id = req.request_id or "" + if emit_idx is not None: + emit_frame(emit_idx) + return PlaybackState(status, cursor, speed, did_seek, request_id) + + server = foxglove.start_server( + name=f"{dataset.repo_id}/episode_{episode_index}", + host=host, + port=port, + capabilities=[Capability.PlaybackControl, Capability.Time], + server_listener=_PlaybackListener(), + playback_time_range=(first_ns, last_ns), + ) + server_holder["server"] = server + + def playback_loop() -> None: + prev = time.monotonic() + while not stop_event.is_set(): + time.sleep(1.0 / 60.0) + with lock: + now = time.monotonic() + dt = now - prev + prev = now + if state["status"] != PlaybackStatus.Playing: + continue + cursor = state["cursor"] + int(dt * 1e9 * state["speed"]) + start_idx = state["last_idx"] + 1 + if cursor >= last_ns: + cursor, target, ended = last_ns, n_frames - 1, True + else: + target, ended = index_at(cursor), False + state["cursor"] = cursor + state["last_idx"] = max(state["last_idx"], target) + if ended: + state["status"] = PlaybackStatus.Ended + speed = state["speed"] + for i in range(start_idx, target + 1): + emit_frame(i) + server.broadcast_time(cursor) + if ended: + server.broadcast_playback_state(PlaybackState(PlaybackStatus.Ended, cursor, speed, False, "")) + + thread = threading.Thread(target=playback_loop, name="foxglove-playback", daemon=True) + thread.start() + + # Emit the first frame so channels are advertised and the viewer isn't blank before playback. + emit_frame(0) + with lock: + state["last_idx"] = 0 + server.broadcast_time(first_ns) + server.broadcast_playback_state(PlaybackState(PlaybackStatus.Paused, first_ns, 1.0, True, "")) + + print(f"Foxglove server running. Connect the Foxglove app to ws://{host}:{port}") + print("Use the playback controls in Foxglove to play/pause and scrub the episode. Ctrl-C to exit.") + try: + while not stop_event.is_set(): + time.sleep(0.5) + except KeyboardInterrupt: + print("Ctrl-C received. Exiting.") + finally: + stop_event.set() + thread.join(timeout=2.0) + server.stop() + _foxglove_channels.clear() + + # ── Backend-agnostic dispatch ───────────────────────────────────────────── # These let callers select a visualization backend at runtime via a string # (e.g. a `--display_mode` CLI flag) without branching on the backend everywhere.