add foxglove option to dataset viz

This commit is contained in:
Roman Shtylman
2026-06-18 15:09:30 -07:00
committed by CarolinePascal
parent 23817e0ab0
commit 77c9983a7d
2 changed files with 314 additions and 44 deletions
+55 -1
View File
@@ -59,6 +59,18 @@ distant$ lerobot-dataset-viz \
local$ rerun rerun+http://IP:GRPC_PORT/proxy
```
- Visualize data in Foxglove with a seekable, scrubbable timeline:
```
local$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0 \
--display-mode foxglove
# then open the Foxglove app and connect to ws://127.0.0.1:8765
```
This starts a Foxglove WebSocket server that serves the episode on demand from the on-disk dataset,
so you can play/pause and scrub anywhere in the episode using Foxglove's playback controls.
"""
import argparse
@@ -150,8 +162,26 @@ def visualize_dataset(
save: bool = False,
output_dir: Path | None = None,
display_compressed_images: bool = False,
display_mode: str = "rerun",
host: str = "127.0.0.1",
port: int = 8765,
**kwargs,
) -> Path | None:
if display_mode == "foxglove":
if save:
logging.warning("--save is ignored with --display-mode foxglove (no .rrd file is written).")
from lerobot.utils.visualization_utils import serve_foxglove_dataset_playback
logging.info("Starting Foxglove server")
serve_foxglove_dataset_playback(
dataset,
episode_index,
host=host,
port=port,
compress_images=display_compressed_images,
)
return None
if save:
assert output_dir is not None, (
"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
@@ -351,7 +381,31 @@ def main():
parser.add_argument(
"--display-compressed-images",
action="store_true",
help="If set, display compressed images in Rerun instead of uncompressed ones.",
help="If set, display compressed (JPEG) images instead of uncompressed ones.",
)
parser.add_argument(
"--display-mode",
type=str,
default="rerun",
choices=["rerun", "foxglove"],
help=(
"Visualization backend. 'rerun' uses the Rerun viewer (--mode/--save/--*-port apply). "
"'foxglove' starts a Foxglove WebSocket server that serves the episode as a seekable, "
"scrubbable timeline; connect the Foxglove app to ws://HOST:PORT (--host/--port)."
),
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Host to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.",
)
parser.add_argument(
"--port",
type=int,
default=8765,
help="Port to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.",
)
args = parser.parse_args()
+259 -43
View File
@@ -20,7 +20,7 @@ import numpy as np
from lerobot.types import RobotAction, RobotObservation
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
from .constants import ACTION, ACTION_PREFIX, DONE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD
from .import_utils import require_package
# Module-level Foxglove state. A single WebSocket server is shared for the
@@ -171,11 +171,15 @@ _SCALARS_SCHEMA = {
}
def _log_foxglove_scalars(topic: str, values: dict[str, float]) -> None:
def _log_foxglove_scalars(topic: str, values: dict[str, float], *, log_time: int | None = None) -> None:
"""Log scalars on a typed JSON channel using the static :data:`_SCALARS_SCHEMA`.
``values`` is an ordered mapping of feature name to value; it is emitted as a ``scalars`` array of
``{label, value}`` objects. Insertion order is preserved so series stay stable across messages.
``log_time`` is the message time in nanoseconds. When ``None`` the server's receive time is used
(correct for live streaming); dataset playback passes the frame's dataset timestamp so the
Foxglove timeline reflects the recorded episode.
"""
if not values:
@@ -188,7 +192,67 @@ def _log_foxglove_scalars(topic: str, values: dict[str, float]) -> None:
channel = _foxglove_channels[topic] = foxglove.Channel(
topic, schema=_SCALARS_SCHEMA, message_encoding="json"
)
channel.log({"scalars": [{"label": label, "value": value} for label, value in values.items()]})
msg = {"scalars": [{"label": label, "value": value} for label, value in values.items()]}
if log_time is None:
channel.log(msg)
else:
channel.log(msg, log_time=log_time)
def _log_foxglove_image(
    topic: str, frame_id: str, arr: np.ndarray, *, compress_images: bool, time_ns: int
) -> None:
    """Log an image on a cached per-topic channel, stamped at ``time_ns`` (nanoseconds).

    ``arr`` may be HWC or CHW; CHW is transposed to HWC. ``time_ns`` sets both the message header
    timestamp and the channel ``log_time`` so the message lands at the right point on the Foxglove
    timeline (matching wall-clock for live streaming, or the dataset timestamp during playback).

    Arrays that are not 2-D/3-D, or whose channel count is not 1, 3 or 4, are skipped silently.
    Pixel data is converted to contiguous ``uint8`` before either encoding path, so the JPEG and
    raw paths treat the input identically.

    Args:
        topic: Foxglove topic to publish on; also the key of the cached channel.
        frame_id: Frame identifier written into the image message.
        arr: Image data as HW, HWC or CHW.
        compress_images: When true, 3-channel images are JPEG-encoded and sent as
            ``CompressedImage``; every other supported image is sent as ``RawImage``.
        time_ns: Message time in nanoseconds.

    Raises:
        RuntimeError: If JPEG encoding was requested and OpenCV reports that it failed.
    """
    from foxglove.channels import CompressedImageChannel, RawImageChannel
    from foxglove.messages import CompressedImage, RawImage, Timestamp

    # Anything other than HW / HWC / CHW is not a single image; skip it the same way an
    # unsupported channel count is skipped below, instead of failing on ``arr.shape[1]``.
    if arr.ndim not in (2, 3):
        return

    timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000)

    # Convert CHW -> HWC when needed (mirrors log_rerun_data).
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        arr = np.transpose(arr, (1, 2, 0))
    height, width = arr.shape[0], arr.shape[1]
    channels = 1 if arr.ndim == 2 else arr.shape[2]

    encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(channels)
    if encoding is None:
        return

    # Normalize once for both paths: the transpose above yields a non-contiguous view, and the
    # JPEG encoder needs the same 8-bit data the raw path has always produced.
    arr = np.ascontiguousarray(arr, dtype=np.uint8)

    if compress_images and channels == 3:
        import cv2

        # Camera frames are RGB; cv2.imencode assumes BGR, so swap to keep colors correct.
        ok, buf = cv2.imencode(".jpg", cv2.cvtColor(arr, cv2.COLOR_RGB2BGR))
        if not ok:
            raise RuntimeError(f"JPEG encoding failed for Foxglove topic {topic!r}.")
        channel = _foxglove_channels.get(topic)
        if channel is None:
            channel = _foxglove_channels[topic] = CompressedImageChannel(topic=topic)
        channel.log(
            CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg"),
            log_time=time_ns,
        )
        return

    channel = _foxglove_channels.get(topic)
    if channel is None:
        channel = _foxglove_channels[topic] = RawImageChannel(topic=topic)
    channel.log(
        RawImage(
            timestamp=timestamp,
            frame_id=frame_id,
            width=width,
            height=height,
            encoding=encoding,
            # Row stride in bytes: one byte per channel after the uint8 conversion.
            step=width * channels,
            data=arr.tobytes(),
        ),
        log_time=time_ns,
    )
def log_rerun_data(
@@ -295,53 +359,14 @@ def log_foxglove_data(
"""
require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
from foxglove.channels import CompressedImageChannel, RawImageChannel
from foxglove.messages import CompressedImage, RawImage, Timestamp
if _foxglove_server is None:
raise RuntimeError("init_foxglove() must be called before log_foxglove_data().")
now = time.time_ns()
timestamp = Timestamp(sec=now // 1_000_000_000, nsec=now % 1_000_000_000)
def log_image(topic: str, frame_id: str, arr: np.ndarray) -> None:
# Convert CHW -> HWC when needed (mirrors log_rerun_data).
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
height, width = arr.shape[0], arr.shape[1]
channels = 1 if arr.ndim == 2 else arr.shape[2]
if compress_images and channels == 3:
import cv2
# Camera frames are RGB; cv2.imencode assumes BGR, so swap to keep colors correct.
_, buf = cv2.imencode(".jpg", cv2.cvtColor(arr, cv2.COLOR_RGB2BGR))
channel = _foxglove_channels.get(topic)
if channel is None:
channel = _foxglove_channels[topic] = CompressedImageChannel(topic=topic)
channel.log(
CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg")
)
return
encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(channels)
if encoding is None:
return
arr = np.ascontiguousarray(arr, dtype=np.uint8)
channel = _foxglove_channels.get(topic)
if channel is None:
channel = _foxglove_channels[topic] = RawImageChannel(topic=topic)
channel.log(
RawImage(
timestamp=timestamp,
frame_id=frame_id,
width=width,
height=height,
encoding=encoding,
step=width * channels,
data=arr.tobytes(),
)
)
_log_foxglove_image(topic, frame_id, arr, compress_images=compress_images, time_ns=now)
if observation:
obs_scalars: dict[str, float] = {}
@@ -374,6 +399,197 @@ def log_foxglove_data(
_log_foxglove_scalars(f"/{ACTION}/state", action_scalars)
# ── Dataset playback over a Foxglove WebSocket server ─────────────────────
# A LeRobotDataset is random-access on disk, so rather than fire-and-forget a forward stream we
# advertise a seekable timeline and serve frames on demand for whatever time the user scrubs/plays
# to in the Foxglove app. This relies on the SDK's PlaybackControl capability.
# Dataset feature key read from each frame and logged under the "success" label on /episode/state.
_SUCCESS = "next.success"
def _frame_to_scalars(sample: dict, key: str) -> dict[str, float]:
    """Turn the feature ``key`` of one frame into a flat ``{label: value}`` mapping.

    A zero-dimensional value produces the single entry ``{"0": value}``. Anything with one or more
    dimensions is flattened and each element is labelled with its flat index (``"0"``, ``"1"``, ...).
    An absent key, or a key whose value is ``None``, produces an empty mapping.
    """
    raw = sample.get(key)
    if raw is None:
        return {}

    # Objects exposing ``.numpy()`` are converted through it; everything else goes via np.asarray.
    values = raw.numpy() if hasattr(raw, "numpy") else np.asarray(raw)

    if values.ndim != 0:
        flat = values.flatten()
        return {str(position): float(element) for position, element in enumerate(flat)}
    return {"0": float(values)}
def serve_foxglove_dataset_playback(
    dataset,
    episode_index: int,
    *,
    host: str = "127.0.0.1",
    port: int = 8765,
    compress_images: bool = False,
) -> None:
    """Serve a single dataset episode to Foxglove as a seekable, scrubbable timeline.

    Starts a Foxglove WebSocket server advertising the ``PlaybackControl`` capability over the
    episode's time range. The Foxglove app drives play/pause/seek/speed; a background thread and a
    ``ServerListener`` read frames from the on-disk ``dataset`` on demand and log them stamped at
    their dataset timestamps, so the user can scrub anywhere in the episode. Blocks until interrupted.

    Args:
        dataset: A ``LeRobotDataset`` loaded for the single episode to visualize.
        episode_index: Index of the episode being visualized (used only for the session name).
        host: Host interface to bind the WebSocket server to.
        port: Port to bind the WebSocket server to.
        compress_images: Whether to JPEG-compress camera frames before logging.

    Raises:
        ValueError: If the dataset exposes no frames.
    """
    require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
    import bisect
    import threading

    import foxglove
    from foxglove.websocket import (
        Capability,
        PlaybackCommand,
        PlaybackControlRequest,
        PlaybackState,
        PlaybackStatus,
        ServerListener,
    )

    # Per-frame timestamps in nanoseconds (read straight from the table, no video decode).
    times_ns = [int(round(float(t) * 1e9)) for t in dataset.hf_dataset["timestamp"]]
    n_frames = len(times_ns)
    if n_frames == 0:
        raise ValueError("Cannot visualize an empty episode.")
    first_ns, last_ns = times_ns[0], times_ns[-1]
    camera_keys = list(dataset.meta.camera_keys)

    # ``lock`` guards the shared playback ``state``. ``emit_lock`` serializes frame emission:
    # emit_frame is called from the listener callback, the playback thread and this thread, and
    # it both reads the dataset and lazily creates entries in the shared channel cache.
    # NOTE(review): assumes the SDK may invoke listener callbacks concurrently with the playback
    # thread -- confirm against the Foxglove SDK threading model.
    lock = threading.Lock()
    emit_lock = threading.Lock()
    stop_event = threading.Event()

    def topic_for(key: str) -> str:
        """Map a camera feature key to its image topic, dropping the observation prefix."""
        name = key[len(OBS_PREFIX) :] if str(key).startswith(OBS_PREFIX) else str(key)
        return f"/{OBS_STR}/images/{_foxglove_safe_name(name)}"

    def emit_frame(i: int) -> None:
        """Log every channel for frame ``i`` stamped at its dataset timestamp."""
        with emit_lock:
            sample = dataset[i]
            log_time = times_ns[i]
            for key in camera_keys:
                arr = sample.get(key)
                if arr is None:
                    continue
                arr = arr.numpy() if hasattr(arr, "numpy") else np.asarray(arr)
                # Floating-point frames are scaled by 255 and clipped into the uint8 range.
                if np.issubdtype(arr.dtype, np.floating):
                    arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
                _log_foxglove_image(
                    topic_for(key), key, arr, compress_images=compress_images, time_ns=log_time
                )
            _log_foxglove_scalars(
                f"/{OBS_STR}/state", _frame_to_scalars(sample, OBS_STATE), log_time=log_time
            )
            _log_foxglove_scalars(f"/{ACTION}/state", _frame_to_scalars(sample, ACTION), log_time=log_time)
            episode_scalars = {}
            for feat, label in ((DONE, "done"), (REWARD, "reward"), (_SUCCESS, "success")):
                v = sample.get(feat)
                if v is not None:
                    episode_scalars[label] = float(v)
            _log_foxglove_scalars("/episode/state", episode_scalars, log_time=log_time)

    # Shared playback state, guarded by ``lock``. ``last_idx`` is the highest frame index already
    # emitted since the last seek; the playback loop resumes from ``last_idx + 1``.
    state = {"status": PlaybackStatus.Paused, "cursor": first_ns, "speed": 1.0, "last_idx": -1}

    def index_at(t_ns: int) -> int:
        """Return the index of the last frame at or before ``t_ns``, clamped to the episode."""
        return max(0, min(n_frames - 1, bisect.bisect_right(times_ns, t_ns) - 1))

    class _PlaybackListener(ServerListener):
        def on_playback_control_request(self, req: PlaybackControlRequest):
            """Apply a seek/speed/play/pause request and return the resulting playback state."""
            emit_idx = None
            with lock:
                did_seek = False
                if req.seek_time is not None:
                    cursor = max(first_ns, min(last_ns, req.seek_time))
                    state["cursor"] = cursor
                    emit_idx = state["last_idx"] = index_at(cursor)
                    did_seek = True
                if req.playback_speed and req.playback_speed > 0:
                    state["speed"] = req.playback_speed
                if req.playback_command == PlaybackCommand.Play:
                    # Restarting from the end replays from the beginning.
                    if state["cursor"] >= last_ns:
                        state["cursor"] = first_ns
                        emit_idx = state["last_idx"] = 0
                        did_seek = True
                    state["status"] = PlaybackStatus.Playing
                elif req.playback_command == PlaybackCommand.Pause:
                    state["status"] = PlaybackStatus.Paused
                status, cursor, speed = state["status"], state["cursor"], state["speed"]
            request_id = req.request_id or ""
            # Emit outside ``lock`` so a slow frame read never blocks state updates.
            if emit_idx is not None:
                emit_frame(emit_idx)
            return PlaybackState(status, cursor, speed, did_seek, request_id)

    server = foxglove.start_server(
        name=f"{dataset.repo_id}/episode_{episode_index}",
        host=host,
        port=port,
        capabilities=[Capability.PlaybackControl, Capability.Time],
        server_listener=_PlaybackListener(),
        playback_time_range=(first_ns, last_ns),
    )

    def playback_loop() -> None:
        """Advance the cursor while playing and emit every frame it passes."""
        prev = time.monotonic()
        while not stop_event.is_set():
            time.sleep(1.0 / 60.0)
            with lock:
                now = time.monotonic()
                dt = now - prev
                # ``prev`` is refreshed even while paused so resuming does not jump ahead.
                prev = now
                if state["status"] != PlaybackStatus.Playing:
                    continue
                cursor = state["cursor"] + int(dt * 1e9 * state["speed"])
                start_idx = state["last_idx"] + 1
                if cursor >= last_ns:
                    cursor, target, ended = last_ns, n_frames - 1, True
                else:
                    target, ended = index_at(cursor), False
                state["cursor"] = cursor
                state["last_idx"] = max(state["last_idx"], target)
                if ended:
                    state["status"] = PlaybackStatus.Ended
                speed = state["speed"]
            # Frame reads happen outside ``lock``; the range is empty when no new frame was reached.
            for i in range(start_idx, target + 1):
                emit_frame(i)
            server.broadcast_time(cursor)
            if ended:
                server.broadcast_playback_state(PlaybackState(PlaybackStatus.Ended, cursor, speed, False, ""))

    thread = threading.Thread(target=playback_loop, name="foxglove-playback", daemon=True)
    thread.start()

    # Emit the first frame so channels are advertised and the viewer isn't blank before playback.
    emit_frame(0)
    with lock:
        state["last_idx"] = 0
    server.broadcast_time(first_ns)
    server.broadcast_playback_state(PlaybackState(PlaybackStatus.Paused, first_ns, 1.0, True, ""))

    print(f"Foxglove server running. Connect the Foxglove app to ws://{host}:{port}")
    print("Use the playback controls in Foxglove to play/pause and scrub the episode. Ctrl-C to exit.")
    try:
        while not stop_event.is_set():
            time.sleep(0.5)
    except KeyboardInterrupt:
        print("Ctrl-C received. Exiting.")
    finally:
        stop_event.set()
        thread.join(timeout=2.0)
        server.stop()
        # Drop cached channels when the server stops.
        _foxglove_channels.clear()
# ── Backend-agnostic dispatch ─────────────────────────────────────────────
# These let callers select a visualization backend at runtime via a string
# (e.g. a `--display_mode` CLI flag) without branching on the backend everywhere.