mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 00:27:15 +00:00
add foxglove option to dataset viz
This commit is contained in:
committed by
CarolinePascal
parent
23817e0ab0
commit
77c9983a7d
@@ -59,6 +59,18 @@ distant$ lerobot-dataset-viz \
|
||||
local$ rerun rerun+http://IP:GRPC_PORT/proxy
|
||||
```
|
||||
|
||||
- Visualize data in Foxglove with a seekable, scrubbable timeline:
|
||||
```
|
||||
local$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--display-mode foxglove
|
||||
|
||||
# then open the Foxglove app and connect to ws://127.0.0.1:8765
|
||||
```
|
||||
This starts a Foxglove WebSocket server that serves the episode on demand from the on-disk dataset,
|
||||
so you can play/pause and scrub anywhere in the episode using Foxglove's playback controls.
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -150,8 +162,26 @@ def visualize_dataset(
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
display_mode: str = "rerun",
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8765,
|
||||
**kwargs,
|
||||
) -> Path | None:
|
||||
if display_mode == "foxglove":
|
||||
if save:
|
||||
logging.warning("--save is ignored with --display-mode foxglove (no .rrd file is written).")
|
||||
from lerobot.utils.visualization_utils import serve_foxglove_dataset_playback
|
||||
|
||||
logging.info("Starting Foxglove server")
|
||||
serve_foxglove_dataset_playback(
|
||||
dataset,
|
||||
episode_index,
|
||||
host=host,
|
||||
port=port,
|
||||
compress_images=display_compressed_images,
|
||||
)
|
||||
return None
|
||||
|
||||
if save:
|
||||
assert output_dir is not None, (
|
||||
"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
|
||||
@@ -351,7 +381,31 @@ def main():
|
||||
parser.add_argument(
|
||||
"--display-compressed-images",
|
||||
action="store_true",
|
||||
help="If set, display compressed images in Rerun instead of uncompressed ones.",
|
||||
help="If set, display compressed (JPEG) images instead of uncompressed ones.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--display-mode",
|
||||
type=str,
|
||||
default="rerun",
|
||||
choices=["rerun", "foxglove"],
|
||||
help=(
|
||||
"Visualization backend. 'rerun' uses the Rerun viewer (--mode/--save/--*-port apply). "
|
||||
"'foxglove' starts a Foxglove WebSocket server that serves the episode as a seekable, "
|
||||
"scrubbable timeline; connect the Foxglove app to ws://HOST:PORT (--host/--port)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8765,
|
||||
help="Port to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -20,7 +20,7 @@ import numpy as np
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
|
||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
||||
from .constants import ACTION, ACTION_PREFIX, DONE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD
|
||||
from .import_utils import require_package
|
||||
|
||||
# Module-level Foxglove state. A single WebSocket server is shared for the
|
||||
@@ -171,11 +171,15 @@ _SCALARS_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
def _log_foxglove_scalars(topic: str, values: dict[str, float]) -> None:
|
||||
def _log_foxglove_scalars(topic: str, values: dict[str, float], *, log_time: int | None = None) -> None:
|
||||
"""Log scalars on a typed JSON channel using the static :data:`_SCALARS_SCHEMA`.
|
||||
|
||||
``values`` is an ordered mapping of feature name to value; it is emitted as a ``scalars`` array of
|
||||
``{label, value}`` objects. Insertion order is preserved so series stay stable across messages.
|
||||
|
||||
``log_time`` is the message time in nanoseconds. When ``None`` the server's receive time is used
|
||||
(correct for live streaming); dataset playback passes the frame's dataset timestamp so the
|
||||
Foxglove timeline reflects the recorded episode.
|
||||
"""
|
||||
|
||||
if not values:
|
||||
@@ -188,7 +192,67 @@ def _log_foxglove_scalars(topic: str, values: dict[str, float]) -> None:
|
||||
channel = _foxglove_channels[topic] = foxglove.Channel(
|
||||
topic, schema=_SCALARS_SCHEMA, message_encoding="json"
|
||||
)
|
||||
channel.log({"scalars": [{"label": label, "value": value} for label, value in values.items()]})
|
||||
msg = {"scalars": [{"label": label, "value": value} for label, value in values.items()]}
|
||||
if log_time is None:
|
||||
channel.log(msg)
|
||||
else:
|
||||
channel.log(msg, log_time=log_time)
|
||||
|
||||
|
||||
def _log_foxglove_image(
    topic: str, frame_id: str, arr: np.ndarray, *, compress_images: bool, time_ns: int
) -> None:
    """Log an image on a cached per-topic channel, stamped at ``time_ns`` (nanoseconds).

    ``arr`` may be HWC or CHW; CHW is transposed to HWC. ``time_ns`` sets both the message header
    timestamp and the channel ``log_time`` so the message lands at the right point on the Foxglove
    timeline (matching wall-clock for live streaming, or the dataset timestamp during playback).

    Args:
        topic: Foxglove topic the image is published on; also the key used to cache the channel
            in the module-level ``_foxglove_channels`` mapping.
        frame_id: Value written to the message's ``frame_id`` field.
        arr: Image array, either 2-D (H, W) or 3-D in HWC / CHW layout.
        compress_images: When true and the image has exactly 3 channels, the frame is JPEG-encoded
            and sent as a ``CompressedImage``; otherwise it is sent as a ``RawImage``.
        time_ns: Timestamp in nanoseconds used for both the header and ``log_time``.

    Images whose channel count is not 1, 3 or 4 are skipped, as are frames whose JPEG encoding
    fails (a warning is logged in that case).
    """

    from foxglove.channels import CompressedImageChannel, RawImageChannel
    from foxglove.messages import CompressedImage, RawImage, Timestamp

    timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000)

    # Convert CHW -> HWC when needed (mirrors log_rerun_data).
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        arr = np.transpose(arr, (1, 2, 0))
    height, width = arr.shape[0], arr.shape[1]
    channels = 1 if arr.ndim == 2 else arr.shape[2]

    if compress_images and channels == 3:
        import cv2

        # Camera frames are RGB; cv2.imencode assumes BGR, so swap to keep colors correct.
        ok, buf = cv2.imencode(".jpg", cv2.cvtColor(arr, cv2.COLOR_RGB2BGR))
        if not ok:
            # imencode reports failure through its first return value; publishing the
            # buffer anyway would send a bogus JPEG, so drop this frame instead.
            import logging

            logging.warning("JPEG encoding failed for Foxglove topic %s; frame skipped.", topic)
            return
        channel = _foxglove_channels.get(topic)
        if channel is None:
            channel = _foxglove_channels[topic] = CompressedImageChannel(topic=topic)
        channel.log(
            CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg"),
            log_time=time_ns,
        )
        return

    encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(channels)
    if encoding is None:
        # No raw encoding is defined here for this channel count; nothing is logged.
        return
    # tobytes() below must yield tightly packed uint8 rows matching ``step``.
    arr = np.ascontiguousarray(arr, dtype=np.uint8)
    channel = _foxglove_channels.get(topic)
    if channel is None:
        channel = _foxglove_channels[topic] = RawImageChannel(topic=topic)
    channel.log(
        RawImage(
            timestamp=timestamp,
            frame_id=frame_id,
            width=width,
            height=height,
            encoding=encoding,
            step=width * channels,
            data=arr.tobytes(),
        ),
        log_time=time_ns,
    )
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
@@ -295,53 +359,14 @@ def log_foxglove_data(
|
||||
"""
|
||||
|
||||
require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
|
||||
from foxglove.channels import CompressedImageChannel, RawImageChannel
|
||||
from foxglove.messages import CompressedImage, RawImage, Timestamp
|
||||
|
||||
if _foxglove_server is None:
|
||||
raise RuntimeError("init_foxglove() must be called before log_foxglove_data().")
|
||||
|
||||
now = time.time_ns()
|
||||
timestamp = Timestamp(sec=now // 1_000_000_000, nsec=now % 1_000_000_000)
|
||||
|
||||
def log_image(topic: str, frame_id: str, arr: np.ndarray) -> None:
|
||||
# Convert CHW -> HWC when needed (mirrors log_rerun_data).
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
height, width = arr.shape[0], arr.shape[1]
|
||||
channels = 1 if arr.ndim == 2 else arr.shape[2]
|
||||
|
||||
if compress_images and channels == 3:
|
||||
import cv2
|
||||
|
||||
# Camera frames are RGB; cv2.imencode assumes BGR, so swap to keep colors correct.
|
||||
_, buf = cv2.imencode(".jpg", cv2.cvtColor(arr, cv2.COLOR_RGB2BGR))
|
||||
channel = _foxglove_channels.get(topic)
|
||||
if channel is None:
|
||||
channel = _foxglove_channels[topic] = CompressedImageChannel(topic=topic)
|
||||
channel.log(
|
||||
CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg")
|
||||
)
|
||||
return
|
||||
|
||||
encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(channels)
|
||||
if encoding is None:
|
||||
return
|
||||
arr = np.ascontiguousarray(arr, dtype=np.uint8)
|
||||
channel = _foxglove_channels.get(topic)
|
||||
if channel is None:
|
||||
channel = _foxglove_channels[topic] = RawImageChannel(topic=topic)
|
||||
channel.log(
|
||||
RawImage(
|
||||
timestamp=timestamp,
|
||||
frame_id=frame_id,
|
||||
width=width,
|
||||
height=height,
|
||||
encoding=encoding,
|
||||
step=width * channels,
|
||||
data=arr.tobytes(),
|
||||
)
|
||||
)
|
||||
_log_foxglove_image(topic, frame_id, arr, compress_images=compress_images, time_ns=now)
|
||||
|
||||
if observation:
|
||||
obs_scalars: dict[str, float] = {}
|
||||
@@ -374,6 +399,197 @@ def log_foxglove_data(
|
||||
_log_foxglove_scalars(f"/{ACTION}/state", action_scalars)
|
||||
|
||||
|
||||
# ── Dataset playback over a Foxglove WebSocket server ─────────────────────
|
||||
# A LeRobotDataset is random-access on disk, so rather than fire-and-forget a forward stream we
|
||||
# advertise a seekable timeline and serve frames on demand for whatever time the user scrubs/plays
|
||||
# to in the Foxglove app. This relies on the SDK's PlaybackControl capability.
|
||||
|
||||
_SUCCESS = "next.success"
|
||||
|
||||
|
||||
def _frame_to_scalars(sample: dict, key: str) -> dict[str, float]:
    """Turn the feature stored under ``key`` in ``sample`` into labelled float entries.

    The feature is converted to a NumPy array (through its ``numpy()`` method when it has one,
    otherwise with ``np.asarray``). A 0-dimensional value produces the single entry ``"0"``;
    anything else is flattened and each element is labelled with its flat index as a string.
    An absent key or a ``None`` value produces an empty mapping.
    """

    feature = sample.get(key)
    if feature is None:
        return {}

    if hasattr(feature, "numpy"):
        data = feature.numpy()
    else:
        data = np.asarray(feature)

    if data.ndim == 0:
        return {"0": float(data)}

    labelled: dict[str, float] = {}
    for position, element in enumerate(data.flatten()):
        labelled[str(position)] = float(element)
    return labelled
|
||||
|
||||
|
||||
def serve_foxglove_dataset_playback(
    dataset,
    episode_index: int,
    *,
    host: str = "127.0.0.1",
    port: int = 8765,
    compress_images: bool = False,
) -> None:
    """Serve a single dataset episode to Foxglove as a seekable, scrubbable timeline.

    Starts a Foxglove WebSocket server advertising the ``PlaybackControl`` capability over the
    episode's time range. The Foxglove app drives play/pause/seek/speed; a background thread and a
    ``ServerListener`` read frames from the on-disk ``dataset`` on demand and log them stamped at
    their dataset timestamps, so the user can scrub anywhere in the episode. Blocks until interrupted.

    Args:
        dataset: A ``LeRobotDataset`` loaded for the single episode to visualize.
        episode_index: Index of the episode being visualized (used only for the session name).
        host: Host interface to bind the WebSocket server to.
        port: Port to bind the WebSocket server to.
        compress_images: Whether to JPEG-compress camera frames before logging.

    Raises:
        ValueError: If the dataset's ``timestamp`` column is empty.
    """

    require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
    import bisect
    import threading

    import foxglove
    from foxglove.websocket import (
        Capability,
        PlaybackCommand,
        PlaybackControlRequest,
        PlaybackState,
        PlaybackStatus,
        ServerListener,
    )

    # Per-frame timestamps in nanoseconds (read straight from the table, no video decode).
    times_ns = [int(round(float(t) * 1e9)) for t in dataset.hf_dataset["timestamp"]]
    n_frames = len(times_ns)
    if n_frames == 0:
        raise ValueError("Cannot visualize an empty episode.")
    first_ns, last_ns = times_ns[0], times_ns[-1]
    camera_keys = list(dataset.meta.camera_keys)

    def topic_for(key: str) -> str:
        """Map a camera feature key to its image topic, dropping a leading ``OBS_PREFIX``."""
        name = key[len(OBS_PREFIX) :] if str(key).startswith(OBS_PREFIX) else str(key)
        return f"/{OBS_STR}/images/{_foxglove_safe_name(name)}"

    def emit_frame(i: int) -> None:
        """Log every channel for frame ``i`` stamped at its dataset timestamp."""
        sample = dataset[i]
        log_time = times_ns[i]
        for key in camera_keys:
            arr = sample.get(key)
            if arr is None:
                continue
            arr = arr.numpy() if hasattr(arr, "numpy") else np.asarray(arr)
            if np.issubdtype(arr.dtype, np.floating):
                # Floating-point frames are scaled by 255 and clipped before the uint8 cast
                # (presumably they are normalized to [0, 1] - confirm against the dataset).
                arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
            _log_foxglove_image(topic_for(key), key, arr, compress_images=compress_images, time_ns=log_time)
        _log_foxglove_scalars(f"/{OBS_STR}/state", _frame_to_scalars(sample, OBS_STATE), log_time=log_time)
        _log_foxglove_scalars(f"/{ACTION}/state", _frame_to_scalars(sample, ACTION), log_time=log_time)
        episode_scalars = {}
        for feat, label in ((DONE, "done"), (REWARD, "reward"), (_SUCCESS, "success")):
            v = sample.get(feat)
            if v is not None:
                episode_scalars[label] = float(v)
        _log_foxglove_scalars("/episode/state", episode_scalars, log_time=log_time)

    lock = threading.Lock()
    stop_event = threading.Event()
    # Shared playback state, guarded by ``lock``.
    # ``cursor`` is the current playback time in ns; ``last_idx`` is the index of the most
    # recently emitted frame (-1 before anything has been emitted).
    state = {"status": PlaybackStatus.Paused, "cursor": first_ns, "speed": 1.0, "last_idx": -1}

    def index_at(t_ns: int) -> int:
        """Return the index of the last frame stamped at or before ``t_ns``, clamped to range."""
        return max(0, min(n_frames - 1, bisect.bisect_right(times_ns, t_ns) - 1))

    class _PlaybackListener(ServerListener):
        """Applies play/pause/seek/speed requests from the client to the shared ``state``."""

        def on_playback_control_request(self, req: PlaybackControlRequest):
            emit_idx = None
            with lock:
                did_seek = False
                if req.seek_time is not None:
                    # Clamp the requested time into the episode range before seeking.
                    cursor = max(first_ns, min(last_ns, req.seek_time))
                    state["cursor"] = cursor
                    emit_idx = state["last_idx"] = index_at(cursor)
                    did_seek = True
                if req.playback_speed and req.playback_speed > 0:
                    state["speed"] = req.playback_speed
                if req.playback_command == PlaybackCommand.Play:
                    # Restarting from the end replays from the beginning.
                    if state["cursor"] >= last_ns:
                        state["cursor"] = first_ns
                        emit_idx = state["last_idx"] = 0
                        did_seek = True
                    state["status"] = PlaybackStatus.Playing
                elif req.playback_command == PlaybackCommand.Pause:
                    state["status"] = PlaybackStatus.Paused
                # Snapshot under the lock so the reply is consistent with what was applied.
                status, cursor, speed = state["status"], state["cursor"], state["speed"]
                request_id = req.request_id or ""
            # Frame loading and logging happen outside the lock so the playback thread is not
            # blocked while the frame is read from the dataset.
            if emit_idx is not None:
                emit_frame(emit_idx)
            return PlaybackState(status, cursor, speed, did_seek, request_id)

    server = foxglove.start_server(
        name=f"{dataset.repo_id}/episode_{episode_index}",
        host=host,
        port=port,
        capabilities=[Capability.PlaybackControl, Capability.Time],
        server_listener=_PlaybackListener(),
        playback_time_range=(first_ns, last_ns),
    )

    def playback_loop() -> None:
        """Advance the cursor at roughly 60 Hz while playing and emit the frames it passes."""
        prev = time.monotonic()
        while not stop_event.is_set():
            time.sleep(1.0 / 60.0)
            with lock:
                now = time.monotonic()
                dt = now - prev
                # ``prev`` is refreshed even while paused so a later resume does not see one
                # huge ``dt`` covering the whole pause.
                prev = now
                if state["status"] != PlaybackStatus.Playing:
                    continue
                cursor = state["cursor"] + int(dt * 1e9 * state["speed"])
                start_idx = state["last_idx"] + 1
                if cursor >= last_ns:
                    cursor, target, ended = last_ns, n_frames - 1, True
                else:
                    target, ended = index_at(cursor), False
                state["cursor"] = cursor
                state["last_idx"] = max(state["last_idx"], target)
                if ended:
                    state["status"] = PlaybackStatus.Ended
                speed = state["speed"]
            # Emit every frame the cursor moved past since the previous tick, outside the lock.
            for i in range(start_idx, target + 1):
                emit_frame(i)
            server.broadcast_time(cursor)
            if ended:
                server.broadcast_playback_state(PlaybackState(PlaybackStatus.Ended, cursor, speed, False, ""))

    thread = threading.Thread(target=playback_loop, name="foxglove-playback", daemon=True)
    thread.start()

    # Emit the first frame so channels are advertised and the viewer isn't blank before playback.
    emit_frame(0)
    with lock:
        state["last_idx"] = 0
    server.broadcast_time(first_ns)
    server.broadcast_playback_state(PlaybackState(PlaybackStatus.Paused, first_ns, 1.0, True, ""))

    print(f"Foxglove server running. Connect the Foxglove app to ws://{host}:{port}")
    print("Use the playback controls in Foxglove to play/pause and scrub the episode. Ctrl-C to exit.")
    try:
        while not stop_event.is_set():
            time.sleep(0.5)
    except KeyboardInterrupt:
        print("Ctrl-C received. Exiting.")
    finally:
        # Stop the playback thread before the server it broadcasts on, then drop the cached
        # channels from the module-level ``_foxglove_channels`` mapping.
        stop_event.set()
        thread.join(timeout=2.0)
        server.stop()
        _foxglove_channels.clear()
|
||||
|
||||
|
||||
# ── Backend-agnostic dispatch ─────────────────────────────────────────────
|
||||
# These let callers select a visualization backend at runtime via a string
|
||||
# (e.g. a `--display_mode` CLI flag) without branching on the backend everywhere.
|
||||
|
||||
Reference in New Issue
Block a user