refactor(viz): split files + autoplay + updated docs + added minimal tests

This commit is contained in:
Steven Palma
2026-07-01 13:42:37 +02:00
parent a33b165b12
commit eab78d882b
9 changed files with 1246 additions and 1011 deletions
+5 -5
View File
@@ -126,7 +126,7 @@ import time
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
from lerobot.utils.visualization_utils import init_visualization, log_visualization_data, shutdown_visualization
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
@@ -142,7 +142,7 @@ teleop_config = SO101LeaderConfig(
id="my_leader_arm",
)
init_rerun(session_name="teleoperation")
init_visualization("rerun", session_name="teleoperation") # pass "foxglove" to stream to Foxglove instead
robot = SO101Follower(robot_config)
teleop_device = SO101Leader(teleop_config)
@@ -158,7 +158,7 @@ while True:
observation = robot.get_observation()
action = teleop_device.get_action()
robot.send_action(action)
log_rerun_data(observation=observation, action=action)
log_visualization_data("rerun", observation=observation, action=action)
elapsed_time = time.perf_counter() - start_time
sleep_time = TIME_PER_FRAME - elapsed_time
@@ -223,7 +223,7 @@ from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_visualization
from lerobot.scripts.lerobot_record import record_loop
from lerobot.processor import make_default_processors
@@ -270,7 +270,7 @@ def main():
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
init_visualization("rerun", session_name="recording")
# Connect the robot and teleoperator
robot.connect()
+2
View File
@@ -265,6 +265,8 @@ lerobot-dataset-viz \
Once executed, the tool opens `rerun.io` and displays the camera streams, robot states, and actions for the selected episode.
To use [Foxglove](https://foxglove.dev) instead of Rerun, install the `foxglove` extra (`pip install 'lerobot[foxglove]'`) and add `--display-mode foxglove`. This starts a WebSocket server (connect the Foxglove app to `ws://127.0.0.1:8765`) that serves the episode as a seekable timeline you can play, pause, and scrub.
For advanced usage—including visualizing datasets stored on a remote server—run:
```bash
+12 -1
View File
@@ -164,12 +164,13 @@ def visualize_dataset(
display_compressed_images: bool = False,
display_mode: str = "rerun",
host: str = "127.0.0.1",
autoplay: bool = True,
**kwargs,
) -> Path | None:
if display_mode == "foxglove":
if save:
logging.warning("--save is ignored with --display-mode foxglove (no .rrd file is written).")
from lerobot.utils.visualization_utils import serve_foxglove_dataset_playback
from lerobot.utils.foxglove_visualization import serve_foxglove_dataset_playback
logging.info("Starting Foxglove server")
serve_foxglove_dataset_playback(
@@ -178,6 +179,7 @@ def visualize_dataset(
host=host,
port=web_port if web_port is not None else 8765,
compress_images=display_compressed_images,
autoplay=autoplay,
)
return None
@@ -408,6 +410,15 @@ def main():
"(127.0.0.1 for local only, 0.0.0.0 for all interfaces)."
),
)
parser.add_argument(
"--no-autoplay",
dest="autoplay",
action="store_false",
help=(
"For `--display-mode foxglove`: don't start playing automatically when a client "
"connects; wait for play to be pressed in the Foxglove app instead."
),
)
args = parser.parse_args()
kwargs = vars(args)
+610
View File
@@ -0,0 +1,610 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Foxglove visualization backend.
Live control-loop streaming (:func:`log_foxglove_data`) and seekable dataset playback
(:func:`serve_foxglove_dataset_playback`) over a Foxglove WebSocket server. Callers usually select a
backend at runtime through the dispatch in :mod:`lerobot.utils.visualization_utils` rather than
importing from here directly. Requires the ``foxglove`` extra (``pip install 'lerobot[foxglove]'``).
"""
import logging
import numbers
import time
import cv2
import numpy as np
from lerobot.types import RobotAction, RobotObservation
from .constants import (
ACTION,
ACTION_PREFIX,
DONE,
OBS_IMAGES,
OBS_PREFIX,
OBS_STATE,
OBS_STR,
REWARD,
SUCCESS,
TRUNCATED,
)
from .import_utils import require_package
# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}``
# pairs rather than one field per feature, so the same schema fits any robot regardless of which
# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name
# each series automatically, so a single filtered path plots every feature, e.g.
# ``/observation/state.scalars[:]``.
_SCALARS_SCHEMA = {
    "type": "object",
    "title": "lerobot.Scalars",
    "properties": {
        # The single top-level field: a list with one entry per scalar feature.
        "scalars": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    # Series name for this entry.
                    "label": {"type": "string"},
                    # Numeric sample for this entry.
                    "value": {"type": "number"},
                },
            },
        }
    },
}
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
def init_foxglove(host: str = "127.0.0.1", port: int | None = 8765) -> None:
    """Start the Foxglove WebSocket server used to visualize the control loop.

    The Foxglove app connects at ``ws://<host>:<port>``. While a server is already running,
    additional calls return without doing anything.

    Args:
        host: Host interface to bind the WebSocket server to.
        port: Port to bind the WebSocket server to; a falsy value (``None``/``0``) means 8765.
    """
    require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
    import foxglove

    # Live-stream state is stored as attributes on ``log_foxglove_data``: ``.server`` is the shared
    # WebSocket server and ``.channels`` caches one Foxglove channel per topic.
    already_running = getattr(log_foxglove_data, "server", None) is not None
    if not already_running:
        log_foxglove_data.server = foxglove.start_server(host=host, port=port if port else 8765)
        log_foxglove_data.channels = {}
def shutdown_foxglove() -> None:
    """Stop the live Foxglove WebSocket server, if one is running, and drop the cached channels."""
    active = getattr(log_foxglove_data, "server", None)
    if active is not None:
        active.stop()
    # The live-stream state is reset whether or not a server was running.
    log_foxglove_data.channels = {}
    log_foxglove_data.server = None
def _foxglove_safe_name(name: str) -> str:
"""Replace ``.`` with ``_`` so a feature name is a single Foxglove topic-path segment.
Foxglove treats ``.`` as a path separator, so an unsanitized name like ``observation.images.front``
would split into nested segments instead of naming one topic.
"""
return name.replace(".", "_")
def _foxglove_topic(key: str, *, is_image: bool = False) -> str:
    """Return the Foxglove topic a feature ``key`` is published on.

    Scalar features share one aggregate topic per source: ``/action/state`` when ``key`` is the
    action key or starts with the action prefix, ``/observation/state`` otherwise. Camera features
    (``is_image=True``) each get their own topic, ``/observation/images/<name>``, where ``<name>``
    is ``key`` with its leading observation prefix removed and dots replaced by underscores.
    """
    text = str(key)
    if not is_image:
        from_action = text == ACTION or text.startswith(ACTION_PREFIX)
        return f"/{ACTION if from_action else OBS_STR}/state"

    # Strip only the first matching prefix, trying the more specific images prefix first.
    short = text
    for prefix in (f"{OBS_IMAGES}.", OBS_PREFIX):
        if text.startswith(prefix):
            short = text[len(prefix) :]
            break
    return f"/{OBS_STR}/images/{_foxglove_safe_name(short)}"
def _log_foxglove_scalars(
    topic: str, values: dict[str, float], *, channels: dict | None = None, log_time: int | None = None
) -> None:
    """Publish ``values`` on ``topic`` as one JSON message following :data:`_SCALARS_SCHEMA`.

    The mapping is emitted as a ``scalars`` array of ``{label, value}`` objects in insertion order,
    so series keep a stable order from one message to the next. An empty mapping logs nothing.

    Args:
        topic: Topic to publish on; its channel is created on first use and then reused.
        values: Ordered mapping of feature name to value.
        channels: Per-topic channel cache. ``None`` selects the live-stream cache stored on
            :func:`log_foxglove_data`; dataset playback passes its own cache.
        log_time: Message time in nanoseconds. When ``None`` no explicit time is passed to the
            channel.
    """
    if not values:
        return
    import foxglove

    cache = log_foxglove_data.channels if channels is None else channels
    channel = cache.get(topic)
    if channel is None:
        channel = foxglove.Channel(topic, schema=_SCALARS_SCHEMA, message_encoding="json")
        cache[topic] = channel

    payload = {"scalars": [{"label": name, "value": number} for name, number in values.items()]}
    # Only forward ``log_time`` when the caller supplied one.
    time_kwargs = {} if log_time is None else {"log_time": log_time}
    channel.log(payload, **time_kwargs)
def _labeled_scalars(name: str, values, labels: list[str] | None = None) -> dict[str, float]:
"""Expand a 1D sequence into ``{label: value}`` entries with a consistent fallback."""
flat = [float(v) for v in values]
if labels is None or len(labels) != len(flat):
labels = [f"{name}_{i}" for i in range(len(flat))]
return dict(zip(labels, flat, strict=True))
def _log_foxglove_image(
    topic: str,
    frame_id: str,
    arr: np.ndarray,
    *,
    compress_images: bool,
    channels: dict | None = None,
    log_time: int | None = None,
) -> None:
    """Log an image on a cached per-topic channel.

    ``arr`` may be HWC or CHW (CHW is transposed to HWC). Floating-point images are assumed
    normalized to [0, 1] and scaled to uint8; every other dtype is cast to uint8 as-is. With
    ``compress_images`` set, grayscale (1ch) and color (3ch) frames are JPEG-encoded, while
    4-channel (RGBA) frames are always sent raw. Frames with any other channel count, and frames
    whose JPEG encoding fails, are skipped with a warning instead of being published.

    Args:
        topic: Topic to publish on; its channel is created on first use and then reused.
        frame_id: Frame identifier written into the image message.
        arr: Image data.
        compress_images: Whether to JPEG-encode 1- and 3-channel frames.
        channels: Per-topic channel cache. ``None`` selects the live-stream cache stored on
            :func:`log_foxglove_data`.
        log_time: Message time in nanoseconds, also written to the message timestamp. When
            ``None`` the current wall-clock time is used for the message timestamp and no explicit
            log time is passed to the channel.
    """
    from foxglove.channels import CompressedImageChannel, RawImageChannel
    from foxglove.messages import CompressedImage, RawImage, Timestamp

    if channels is None:
        channels = log_foxglove_data.channels

    time_ns = time.time_ns() if log_time is None else log_time
    timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000)
    log_kwargs = {} if log_time is None else {"log_time": log_time}

    # Convert CHW -> HWC when needed (mirrors log_rerun_data).
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        arr = np.transpose(arr, (1, 2, 0))
    if np.issubdtype(arr.dtype, np.floating):
        arr = (arr * 255.0).clip(0, 255)
    arr = np.ascontiguousarray(arr, dtype=np.uint8)

    height, width = arr.shape[0], arr.shape[1]
    n_channels = 1 if arr.ndim == 2 else arr.shape[2]

    if compress_images and n_channels in (1, 3):
        # OpenCV encodes from BGR, so RGB frames are swapped first; grayscale is passed through.
        buf_src = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) if n_channels == 3 else arr
        ok, buf = cv2.imencode(".jpg", buf_src)
        if not ok:
            # ``imencode`` reports failure through its first return value; publishing ``buf``
            # anyway would send a corrupt image, so drop the frame instead.
            logging.warning(
                "Foxglove: JPEG encoding failed for image on topic '%s' with shape %s; skipping frame.",
                topic,
                tuple(arr.shape),
            )
            return
        channel = channels.get(topic)
        if channel is None:
            channel = channels[topic] = CompressedImageChannel(topic=topic)
        channel.log(
            CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg"),
            **log_kwargs,
        )
        return

    encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(n_channels)
    if encoding is None:
        logging.warning(
            "Foxglove: skipping image on topic '%s' with unsupported shape %s (%d channels); "
            "expected 1 (mono8), 3 (rgb8), or 4 (rgba8) channels.",
            topic,
            tuple(arr.shape),
            n_channels,
        )
        return

    channel = channels.get(topic)
    if channel is None:
        channel = channels[topic] = RawImageChannel(topic=topic)
    channel.log(
        RawImage(
            timestamp=timestamp,
            frame_id=frame_id,
            width=width,
            height=height,
            encoding=encoding,
            # Bytes per row: the array is contiguous uint8, one byte per channel sample.
            step=width * n_channels,
            data=arr.tobytes(),
        ),
        **log_kwargs,
    )
def log_foxglove_data(
    observation: RobotObservation | None = None,
    action: RobotAction | None = None,
    compress_images: bool = False,
) -> None:
    """
    Stream one step of observation/action data to the live Foxglove WebSocket server.

    The server must have been started with :func:`init_foxglove`. Every message logged by one call
    carries the same timestamp. Values are routed as follows:

    - Scalars, and each element of a 1D observation array, are collected per source and published
      once on ``/observation/state`` or ``/action/state`` as a JSON message following the static
      ``lerobot.Scalars`` schema (a ``scalars`` array of ``{label, value}`` objects, see
      :data:`_SCALARS_SCHEMA`).
    - Observation arrays with more than one dimension are treated as images and published on a
      per-camera topic (e.g. ``/observation/images/front``), raw or JPEG-compressed depending on
      ``compress_images``.
    - Action arrays of any dimensionality are flattened into scalars.
    - ``None`` values, and values that are neither scalars nor NumPy arrays, are ignored.

    Args:
        observation: An optional dictionary containing observation data to log.
        action: An optional dictionary containing action data to log.
        compress_images: Whether to JPEG-compress images before logging to save bandwidth in
            exchange for CPU and quality.

    Raises:
        RuntimeError: If no Foxglove server has been started.
    """
    require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
    if getattr(log_foxglove_data, "server", None) is None:
        raise RuntimeError("init_foxglove() must be called before log_foxglove_data().")

    # One timestamp for the whole call so images and scalars of a step line up.
    stamp_ns = time.time_ns()

    if observation:
        state_values: dict[str, float] = {}
        for raw_key, value in observation.items():
            if value is None:
                continue
            short = raw_key[len(OBS_PREFIX) :] if str(raw_key).startswith(OBS_PREFIX) else str(raw_key)
            if _is_scalar(value):
                state_values[short] = float(value)
                continue
            if not isinstance(value, np.ndarray):
                continue
            if value.ndim == 1:
                state_values.update(_labeled_scalars(short, value))
                continue
            # Anything with more dimensions is handed to the image logger.
            _log_foxglove_image(
                _foxglove_topic(raw_key, is_image=True),
                short,
                value,
                compress_images=compress_images,
                log_time=stamp_ns,
            )
        _log_foxglove_scalars(_foxglove_topic(OBS_STATE), state_values, log_time=stamp_ns)

    if action:
        command_values: dict[str, float] = {}
        for raw_key, value in action.items():
            if value is None:
                continue
            short = raw_key[len(ACTION_PREFIX) :] if str(raw_key).startswith(ACTION_PREFIX) else str(raw_key)
            if _is_scalar(value):
                command_values[short] = float(value)
            elif isinstance(value, np.ndarray):
                command_values.update(_labeled_scalars(short, value.flatten()))
        _log_foxglove_scalars(_foxglove_topic(ACTION), command_values, log_time=stamp_ns)
# ── Dataset playback over a Foxglove WebSocket server ─────────────────────
# A LeRobotDataset is random-access on disk, so rather than fire-and-forget a forward stream we
# advertise a seekable timeline and serve frames on demand for whatever time the user scrubs/plays
# to in the Foxglove app. This relies on the SDK's PlaybackControl capability.
def _feature_dim_names(feature: dict | None) -> list[str] | None:
"""Best-effort per-dimension series labels for a 1D feature, or ``None`` to fall back to indices.
LeRobot records a feature's ``names`` inconsistently: a flat list (``["x", "y"]``), a category
mapping (``{"motors": ["motor_0", "motor_1"]}``), or a name->index mapping
(``{"delta_x": 0, "delta_y": 1}``). Each is handled, but labels are only returned when their count
matches the feature's 1D shape, so a malformed/mismatched ``names`` can't silently mislabel series.
"""
if not feature:
return None
shape = feature.get("shape")
dim = shape[0] if shape and len(shape) == 1 else None
names = feature.get("names")
labels: list[str] | None = None
if isinstance(names, dict):
values = list(names.values())
if values and all(isinstance(v, (list, tuple)) for v in values):
labels = [str(n) for group in values for n in group]
elif values and all(isinstance(v, int) and not isinstance(v, bool) for v in values):
labels = [name for name, _ in sorted(names.items(), key=lambda kv: kv[1])]
elif isinstance(names, (list, tuple)):
labels = [str(n) for n in names]
if labels is not None and dim is not None and len(labels) == dim:
return labels
return None
def _frame_to_scalars(sample: dict, key: str, labels: list[str] | None = None) -> dict[str, float]:
    """Turn the feature ``key`` of one dataset frame into ``{label: value}`` entries.

    The feature's short name is ``key`` without its observation/action prefix. A 0-d value becomes
    a single entry under that short name. Anything else is flattened and labeled with ``labels``
    (one per dimension) or, when those are absent or the wrong length, with ``{name}_{i}``. A
    missing or ``None`` feature produces an empty mapping.
    """
    value = sample.get(key)
    if value is None:
        return {}
    array = value.numpy() if hasattr(value, "numpy") else np.asarray(value)

    # Drop the first matching source prefix to get the short feature name.
    short = key
    for prefix in (OBS_PREFIX, ACTION_PREFIX):
        if key.startswith(prefix):
            short = key[len(prefix) :]
            break

    if array.ndim != 0:
        return _labeled_scalars(short, array.flatten(), labels)
    return {short: float(array)}
def serve_foxglove_dataset_playback(
    dataset,
    episode_index: int,
    *,
    host: str = "127.0.0.1",
    port: int = 8765,
    compress_images: bool = False,
    autoplay: bool = True,
) -> None:
    """Serve a single dataset episode to Foxglove as a seekable, scrubbable timeline.

    Starts a Foxglove WebSocket server advertising the ``PlaybackControl`` capability over the
    episode's time range. The Foxglove app drives play/pause/seek/speed; a ``ServerListener``
    records those requests and a background thread reads frames from the on-disk ``dataset`` on
    demand and logs them stamped at their dataset timestamps, so the user can scrub anywhere in
    the episode. Blocks until interrupted with Ctrl-C, then stops the thread and the server.

    Args:
        dataset: A ``LeRobotDataset`` loaded for the single episode to visualize.
        episode_index: Index of the episode being visualized (used only for the session name).
        host: Host interface to bind the WebSocket server to.
        port: Port to bind the WebSocket server to.
        compress_images: Whether to JPEG-compress camera frames before logging.
        autoplay: If True, start playing automatically as soon as a client connects, instead of
            waiting for the user to press play in the Foxglove app.

    Raises:
        ValueError: If the episode contains no frames.
    """
    require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
    import bisect
    import threading

    import foxglove
    from foxglove.websocket import (
        Capability,
        PlaybackCommand,
        PlaybackControlRequest,
        PlaybackState,
        PlaybackStatus,
        ServerListener,
    )

    # Per-frame timestamps in nanoseconds (read straight from the table, no video decode).
    times_ns = [int(round(float(t) * 1e9)) for t in dataset.hf_dataset["timestamp"]]
    n_frames = len(times_ns)
    if n_frames == 0:
        raise ValueError("Cannot visualize an empty episode.")
    first_ns, last_ns = times_ns[0], times_ns[-1]

    camera_keys = list(dataset.meta.camera_keys)
    # Per-dimension series labels from the dataset metadata (e.g. joint names), computed once.
    scalar_labels = {
        OBS_STATE: _feature_dim_names(dataset.meta.features.get(OBS_STATE)),
        ACTION: _feature_dim_names(dataset.meta.features.get(ACTION)),
    }
    # Local channel cache so the playback server is self-contained and doesn't touch the live-stream cache.
    channels: dict = {}

    def emit_frame(i: int) -> None:
        """Log every channel for frame ``i`` stamped at its dataset timestamp."""
        sample = dataset[i]
        log_time = times_ns[i]
        for key in camera_keys:
            arr = sample.get(key)
            if arr is None:
                continue
            arr = arr.numpy() if hasattr(arr, "numpy") else np.asarray(arr)
            _log_foxglove_image(
                _foxglove_topic(key, is_image=True),
                key,
                arr,
                compress_images=compress_images,
                channels=channels,
                log_time=log_time,
            )
        _log_foxglove_scalars(
            _foxglove_topic(OBS_STATE),
            _frame_to_scalars(sample, OBS_STATE, scalar_labels[OBS_STATE]),
            channels=channels,
            log_time=log_time,
        )
        _log_foxglove_scalars(
            _foxglove_topic(ACTION),
            _frame_to_scalars(sample, ACTION, scalar_labels[ACTION]),
            channels=channels,
            log_time=log_time,
        )
        # Optional per-frame episode signals; only the ones present in the sample are logged.
        episode_scalars = {}
        for feat, label in (
            (DONE, "done"),
            (TRUNCATED, "truncated"),
            (REWARD, "reward"),
            (SUCCESS, "success"),
        ):
            v = sample.get(feat)
            if v is not None:
                episode_scalars[label] = float(v)
        _log_foxglove_scalars("/episode/state", episode_scalars, channels=channels, log_time=log_time)

    lock = threading.Lock()
    stop_event = threading.Event()
    # Shared playback state, guarded by ``lock``. ``seek_idx`` is a one-shot request set by the
    # listener and serviced by the playback loop, which is the *only* thread that emits frames (so
    # concurrent random access into the on-disk dataset / video decoder never overlaps).
    state = {
        "status": PlaybackStatus.Paused,
        "cursor": first_ns,
        "speed": 1.0,
        "last_idx": -1,
        "seek_idx": None,
    }

    def index_at(t_ns: int) -> int:
        """Index of the last frame whose timestamp is <= ``t_ns``, clamped to the episode."""
        return max(0, min(n_frames - 1, bisect.bisect_right(times_ns, t_ns) - 1))

    # One-shot latch so autoplay fires only on the first client subscription.
    autoplay_started = threading.Event()

    class _PlaybackListener(ServerListener):
        """Translates Foxglove client events into updates of the shared playback ``state``."""

        def on_subscribe(self, client, channel):
            # Start playing automatically once a client actually connects (subscribes). Using the
            # subscribe hook, rather than starting in Playing up front, means the timeline doesn't
            # advance before anyone is watching. Fires once; the user can still pause/seek after.
            if not autoplay:
                return
            with lock:
                if autoplay_started.is_set() or state["status"] != PlaybackStatus.Paused:
                    return
                autoplay_started.set()
                state["status"] = PlaybackStatus.Playing
                cursor, speed = state["cursor"], state["speed"]
            server.broadcast_playback_state(PlaybackState(PlaybackStatus.Playing, cursor, speed, False, ""))

        def on_playback_control_request(self, req: PlaybackControlRequest):
            # Only mutate state here; the playback loop performs all frame emission.
            with lock:
                did_seek = False
                if req.seek_time is not None:
                    cursor = max(first_ns, min(last_ns, req.seek_time))
                    state["cursor"] = cursor
                    state["last_idx"] = state["seek_idx"] = index_at(cursor)
                    did_seek = True
                if req.playback_speed and req.playback_speed > 0:
                    state["speed"] = req.playback_speed
                if req.playback_command == PlaybackCommand.Play:
                    # Restarting from the end replays from the beginning.
                    if state["cursor"] >= last_ns:
                        state["cursor"] = first_ns
                        state["last_idx"] = state["seek_idx"] = 0
                        did_seek = True
                    state["status"] = PlaybackStatus.Playing
                elif req.playback_command == PlaybackCommand.Pause:
                    state["status"] = PlaybackStatus.Paused
                status, cursor, speed = state["status"], state["cursor"], state["speed"]
            request_id = req.request_id or ""
            return PlaybackState(status, cursor, speed, did_seek, request_id)

    server = foxglove.start_server(
        name=f"{dataset.repo_id}/episode_{episode_index}",
        host=host,
        port=port,
        capabilities=[Capability.PlaybackControl, Capability.Time],
        server_listener=_PlaybackListener(),
        playback_time_range=(first_ns, last_ns),
    )

    def playback_loop() -> None:
        # Cap how far the cursor may advance in a single tick. A slow frame decode (or any stall)
        # would otherwise make ``dt`` huge and produce one enormous catch-up batch; clamping it makes
        # playback trail wall-clock under a slow decoder while each tick emits a bounded frame range.
        max_tick_dt_s = 0.25
        prev = time.monotonic()
        while not stop_event.is_set():
            time.sleep(1.0 / 60.0)
            ended = False
            speed = 1.0
            with lock:
                now = time.monotonic()
                dt = min(now - prev, max_tick_dt_s)
                prev = now
                # A queued seek is always serviced, even while paused, so scrubbing updates the view.
                work = []
                seek_idx = state["seek_idx"]
                if seek_idx is not None:
                    state["seek_idx"] = None
                    work.append(seek_idx)
                if state["status"] == PlaybackStatus.Playing:
                    cursor = state["cursor"] + int(dt * 1e9 * state["speed"])
                    start_idx = state["last_idx"] + 1
                    if cursor >= last_ns:
                        cursor, target, ended = last_ns, n_frames - 1, True
                    else:
                        target = index_at(cursor)
                    state["cursor"] = cursor
                    work.extend(range(start_idx, target + 1))
                    # cursor only grows while playing (seeks reset last_idx in the listener), so
                    # target >= last_idx here; a plain assignment is correct and clearer than max().
                    state["last_idx"] = target
                    if ended:
                        state["status"] = PlaybackStatus.Ended
                # Skip idle ticks, but never the tick on which playback ended: even with no frames
                # left to emit (e.g. a single-frame episode started by autoplay) clients must still
                # be told the status changed to Ended, otherwise they keep showing "playing".
                if not work and not ended:
                    continue
                cursor, speed = state["cursor"], state["speed"]
            # Emit outside the lock; this is the only thread that calls emit_frame. Re-check
            # stop_event between frames so shutdown stays responsive even mid-batch.
            for i in work:
                if stop_event.is_set():
                    break
                emit_frame(i)
            server.broadcast_time(cursor)
            if ended:
                server.broadcast_playback_state(PlaybackState(PlaybackStatus.Ended, cursor, speed, False, ""))

    # Emit the first frame so channels are advertised (done before the loop starts, so emission stays
    # single-threaded). Late-connecting clients re-receive frames once they seek/play.
    emit_frame(0)
    with lock:
        state["last_idx"] = 0
    server.broadcast_time(first_ns)
    server.broadcast_playback_state(PlaybackState(PlaybackStatus.Paused, first_ns, 1.0, True, ""))

    thread = threading.Thread(target=playback_loop, name="foxglove-playback", daemon=True)
    thread.start()

    print(f"Foxglove server running. Connect the Foxglove app to ws://{host}:{port}")
    print("Use the playback controls in Foxglove to play/pause and scrub the episode. Ctrl-C to exit.")
    try:
        while not stop_event.is_set():
            time.sleep(0.5)
    except KeyboardInterrupt:
        print("Ctrl-C received. Exiting.")
    finally:
        # Stop the emitter first so nothing is logged while the server shuts down.
        stop_event.set()
        thread.join(timeout=2.0)
        server.stop()
        channels.clear()
+184
View File
@@ -0,0 +1,184 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rerun visualization backend.
Live control-loop streaming to the Rerun viewer (:func:`log_rerun_data`). Callers usually select a
backend at runtime through the dispatch in :mod:`lerobot.utils.visualization_utils` rather than
importing from here directly. Requires the ``viz`` extra (``pip install 'lerobot[viz]'``).
"""
import numbers
import os
import numpy as np
from lerobot.types import RobotAction, RobotObservation
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
from .import_utils import require_package
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
def init_rerun(
    session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
) -> None:
    """
    Set up the Rerun SDK for visualizing the control loop.

    Clears the cached blueprint, makes sure ``RERUN_FLUSH_NUM_BYTES`` is set (``8000`` unless the
    environment already defines it), initializes the session, and then either connects to a
    running Rerun server (when both ``ip`` and ``port`` are given) or spawns a local viewer whose
    memory limit comes from ``LEROBOT_RERUN_MEMORY_LIMIT`` (default ``10%``).

    Args:
        session_name: Name of the Rerun session.
        ip: Optional IP for connecting to a Rerun server.
        port: Optional port for connecting to a Rerun server.
    """
    require_package("rerun-sdk", extra="viz", import_name="rerun")
    import rerun as rr

    # A new session must rebuild its layout from the first logged data.
    log_rerun_data.blueprint = None

    os.environ.setdefault("RERUN_FLUSH_NUM_BYTES", "8000")
    rr.init(session_name)

    if not (ip and port):
        rr.spawn(memory_limit=os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%"))
        return
    rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
def shutdown_rerun() -> None:
    """Shut the Rerun SDK down by calling its ``rerun_shutdown`` hook."""
    require_package("rerun-sdk", extra="viz", import_name="rerun")
    import rerun

    rerun.rerun_shutdown()
def _build_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]):
    """Assemble a Rerun blueprint that arranges all views in one grid.

    Each image path gets its own 2D spatial view (in sorted order), followed by one time-series
    view named ``observation`` and one named ``action`` for whichever of those path sets is
    non-empty.
    """
    # Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
    import rerun.blueprint as rrb

    panels = []
    for path in sorted(image_paths):
        panels.append(rrb.Spatial2DView(origin=path, name=path))
    for title, paths in (("observation", observation_paths), ("action", action_paths)):
        if paths:
            panels.append(rrb.TimeSeriesView(name=title, contents=sorted(paths)))
    return rrb.Blueprint(rrb.Grid(*panels))
def _ensure_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]) -> None:
    """Send the blueprint the first time any data paths are available, then never again.

    The built blueprint is cached on ``log_rerun_data.blueprint``; once that is set, or while all
    three path sets are empty, the call does nothing.
    """
    already_sent = getattr(log_rerun_data, "blueprint", None) is not None
    nothing_logged = not (observation_paths or action_paths or image_paths)
    if already_sent or nothing_logged:
        return

    # Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
    import rerun as rr

    layout = _build_blueprint(observation_paths, action_paths, image_paths)
    log_rerun_data.blueprint = layout
    rr.send_blueprint(layout)
def log_rerun_data(
    observation: RobotObservation | None = None,
    action: RobotAction | None = None,
    compress_images: bool = False,
) -> None:
    """
    Send one step of observation and action data to Rerun for real-time visualization.

    Keys that do not already start with the observation/action prefix are namespaced with
    "observation." or "action.". Values are handled by type:

    - Scalar values (floats, ints, 0-d arrays) are logged as `rr.Scalars`.
    - 1D observation arrays are logged as a single `rr.Scalars` batch under one entity path, so
      every dimension shares the same view.
    - Other observation arrays are logged as static images. A 3D array that looks channel-first
      (1, 3 or 4 leading channels) is transposed from CHW to HWC first; an array whose last
      dimension is 1 becomes an `rr.DepthImage`, anything else an `rr.Image`, compressed when
      ``compress_images`` is set.
    - Action arrays of any dimensionality are flattened and logged as a single `rr.Scalars` batch.
    - ``None`` values, and values that are neither scalars nor NumPy arrays, are skipped.

    The paths logged by the first call that has data are used to build and send a blueprint, giving
    observation and action scalars separate time-series views and each image its own spatial view.

    Args:
        observation: An optional dictionary containing observation data to log.
        action: An optional dictionary containing action data to log.
        compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
    """
    require_package("rerun-sdk", extra="viz", import_name="rerun")
    import rerun as rr

    observation_paths: set[str] = set()
    action_paths: set[str] = set()
    image_paths: set[str] = set()

    if observation:
        for raw_key, value in observation.items():
            if value is None:
                continue
            path = raw_key if str(raw_key).startswith(OBS_PREFIX) else f"{OBS_STR}.{raw_key}"

            if _is_scalar(value):
                rr.log(path, rr.Scalars(float(value)))
                observation_paths.add(path)
                continue
            if not isinstance(value, np.ndarray):
                continue

            frame = value
            # Convert CHW -> HWC when needed
            channel_first = frame.ndim == 3 and frame.shape[0] in (1, 3, 4) and frame.shape[-1] not in (1, 3, 4)
            if channel_first:
                frame = np.transpose(frame, (1, 2, 0))

            if frame.ndim == 1:
                rr.log(path, rr.Scalars(frame.astype(float)))
                observation_paths.add(path)
                continue

            if frame.shape[-1] == 1:
                entity = rr.DepthImage(frame, colormap=rr.components.Colormap.Viridis)
            elif compress_images:
                entity = rr.Image(frame).compress()
            else:
                entity = rr.Image(frame)
            rr.log(path, entity=entity, static=True)
            image_paths.add(path)

    if action:
        for raw_key, value in action.items():
            if value is None:
                continue
            path = raw_key if str(raw_key).startswith(ACTION_PREFIX) else f"{ACTION}.{raw_key}"

            if _is_scalar(value):
                batch = rr.Scalars(float(value))
            elif isinstance(value, np.ndarray):
                # Flatten any (incl. higher-dimensional) array into a single batched Scalars
                batch = rr.Scalars(value.reshape(-1).astype(float))
            else:
                continue
            rr.log(path, batch)
            action_paths.add(path)

    _ensure_blueprint(observation_paths, action_paths, image_paths)
+9 -718
View File
@@ -12,732 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numbers
import os
import time
"""Backend-agnostic visualization dispatch.
import cv2
import numpy as np
Selects a visualization backend at runtime via a display-mode string (e.g. a ``--display_mode`` CLI
flag) so callers never branch on the backend. The concrete implementations live in
:mod:`lerobot.utils.rerun_visualization` and :mod:`lerobot.utils.foxglove_visualization`; importing
this module does not import ``rerun`` or ``foxglove`` (each backend imports its SDK lazily behind a
``require_package`` guard).
"""
from lerobot.types import RobotAction, RobotObservation
from .constants import (
ACTION,
ACTION_PREFIX,
DONE,
OBS_IMAGES,
OBS_PREFIX,
OBS_STATE,
OBS_STR,
REWARD,
SUCCESS,
TRUNCATED,
)
from .import_utils import require_package
from .foxglove_visualization import init_foxglove, log_foxglove_data, shutdown_foxglove
from .rerun_visualization import init_rerun, log_rerun_data, shutdown_rerun
# Visualization backends selectable at runtime via a display-mode string (e.g. a --display_mode flag).
VISUALIZATION_MODES = ("rerun", "foxglove")

# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}``
# pairs rather than one field per feature, so the same schema fits any robot regardless of which
# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name
# each series automatically, so a single filtered path plots every feature, e.g.
# ``/observation/state.scalars[:]``.
#
# Shape of one message described by this schema:
#   {"scalars": [{"label": "<feature name>", "value": <number>}, ...]}
# It is passed as ``schema=`` when ``_log_foxglove_scalars`` creates a JSON channel for a topic.
_SCALARS_SCHEMA = {
    "type": "object",
    "title": "lerobot.Scalars",
    "properties": {
        "scalars": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "label": {"type": "string"},
                    "value": {"type": "number"},
                },
            },
        }
    },
}
def init_rerun(
    session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
) -> None:
    """Start a Rerun session for the control loop and attach it to a viewer.

    When both ``ip`` and ``port`` are given the session connects to that Rerun server over gRPC;
    otherwise a local viewer is spawned, with its memory limit read from the
    ``LEROBOT_RERUN_MEMORY_LIMIT`` environment variable (default ``"10%"``).

    Args:
        session_name: Name of the Rerun session.
        ip: Optional IP of a Rerun server to connect to.
        port: Optional port of a Rerun server to connect to.
    """
    require_package("rerun-sdk", extra="viz", import_name="rerun")
    import rerun as rr

    # A new session must rebuild its layout, so forget any blueprint cached by a previous one.
    log_rerun_data.blueprint = None

    # Keep a user-provided flush size; only fill in the default when the variable is unset.
    os.environ.setdefault("RERUN_FLUSH_NUM_BYTES", "8000")

    rr.init(session_name)

    if ip and port:
        rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
        return
    rr.spawn(memory_limit=os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%"))
def shutdown_rerun() -> None:
    """Ask the Rerun SDK to shut down (requires the ``rerun-sdk`` package)."""
    require_package("rerun-sdk", extra="viz", import_name="rerun")
    import rerun

    rerun.rerun_shutdown()
def init_foxglove(host: str = "127.0.0.1", port: int | None = 8765) -> None:
    """Start the shared Foxglove WebSocket server used for live visualization.

    The Foxglove app connects to it at ``ws://<host>:<port>``. If a server is already registered
    on ``log_foxglove_data`` the call returns without starting another one.

    Args:
        host: Host interface to bind the WebSocket server to.
        port: Port to bind the WebSocket server to; a falsy value falls back to 8765.
    """
    require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
    import foxglove

    # Live-stream state is stored as attributes of ``log_foxglove_data``:
    #   ``.server``   -> the shared WebSocket server
    #   ``.channels`` -> one cached Foxglove channel per topic
    already_running = getattr(log_foxglove_data, "server", None) is not None
    if not already_running:
        log_foxglove_data.server = foxglove.start_server(host=host, port=port or 8765)
        log_foxglove_data.channels = {}
def shutdown_foxglove() -> None:
    """Stop the live Foxglove server (if one is registered) and reset the cached state."""
    running = getattr(log_foxglove_data, "server", None)
    if running is not None:
        running.stop()
    # Reset both attributes so a later ``init_foxglove`` starts from a clean slate.
    log_foxglove_data.server = None
    log_foxglove_data.channels = {}
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
def _build_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]):
    """Assemble a Rerun blueprint: one 2D view per image path plus time-series views for scalars.

    Image views come first (sorted by path), followed by an ``observation`` and then an ``action``
    time-series view when the corresponding path set is non-empty. All views are placed in a grid.
    """
    # The caller (`log_rerun_data`) has already run the `require_package` guard and imported rerun.
    import rerun.blueprint as rrb

    layout = []
    for image_path in sorted(image_paths):
        layout.append(rrb.Spatial2DView(origin=image_path, name=image_path))
    for series_name, series_paths in (("observation", observation_paths), ("action", action_paths)):
        if series_paths:
            layout.append(rrb.TimeSeriesView(name=series_name, contents=sorted(series_paths)))
    return rrb.Blueprint(rrb.Grid(*layout))
def _ensure_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]) -> None:
    """Send the blueprint exactly once, built from the first call that logged any path."""
    already_sent = getattr(log_rerun_data, "blueprint", None) is not None
    nothing_logged = not (observation_paths or action_paths or image_paths)
    if already_sent or nothing_logged:
        return

    # The caller (`log_rerun_data`) has already run the `require_package` guard and imported rerun.
    import rerun as rr

    layout = _build_blueprint(observation_paths, action_paths, image_paths)
    # Cache on the function object so later calls are no-ops until `init_rerun` resets it.
    log_rerun_data.blueprint = layout
    rr.send_blueprint(layout)
def _foxglove_safe_name(name: str) -> str:
"""Replace ``.`` with ``_`` so a feature name is a single Foxglove topic-path segment.
Foxglove treats ``.`` as a path separator, so an unsanitized name like ``observation.images.front``
would split into nested segments instead of naming one topic.
"""
return name.replace(".", "_")
def _foxglove_topic(key: str, *, is_image: bool = False) -> str:
    """Return the Foxglove topic a feature ``key`` is published on.

    Scalar features share one aggregate topic per source: ``/action/state`` when the key is the
    action key or carries the action prefix, ``/observation/state`` otherwise. Image features get
    a per-camera topic ``/observation/images/<name>``, where ``<name>`` is the key with its
    leading observation/images prefix removed and dots replaced by underscores.
    """
    text = str(key)
    if not is_image:
        is_action = text == ACTION or text.startswith(ACTION_PREFIX)
        return f"/{ACTION if is_action else OBS_STR}/state"

    camera = text
    # Strip the most specific prefix first so "images" is not repeated in the topic.
    for prefix in (f"{OBS_IMAGES}.", OBS_PREFIX):
        if camera.startswith(prefix):
            camera = camera[len(prefix) :]
            break
    return f"/{OBS_STR}/images/{_foxglove_safe_name(camera)}"
def _log_foxglove_scalars(
    topic: str, values: dict[str, float], *, channels: dict | None = None, log_time: int | None = None
) -> None:
    """Publish ``values`` on ``topic`` as one JSON message following :data:`_SCALARS_SCHEMA`.

    The mapping is emitted as a ``scalars`` array of ``{label, value}`` objects in insertion order.
    Nothing is published when ``values`` is empty.

    Args:
        topic: Foxglove topic to publish on.
        values: Mapping of series label to value.
        channels: Per-topic channel cache to reuse; defaults to the cache stored on
            ``log_foxglove_data`` for live streaming.
        log_time: Message time in nanoseconds; when ``None`` no explicit time is passed to
            ``channel.log``.
    """
    if not values:
        return

    import foxglove

    cache = log_foxglove_data.channels if channels is None else channels
    channel = cache.get(topic)
    if channel is None:
        # First message on this topic: create the typed JSON channel and remember it.
        channel = foxglove.Channel(topic, schema=_SCALARS_SCHEMA, message_encoding="json")
        cache[topic] = channel

    payload = {"scalars": [{"label": name, "value": number} for name, number in values.items()]}
    time_kwargs = {} if log_time is None else {"log_time": log_time}
    channel.log(payload, **time_kwargs)
def _labeled_scalars(name: str, values, labels: list[str] | None = None) -> dict[str, float]:
"""Expand a 1D sequence into ``{label: value}`` entries with a consistent fallback."""
flat = [float(v) for v in values]
if labels is None or len(labels) != len(flat):
labels = [f"{name}_{i}" for i in range(len(flat))]
return dict(zip(labels, flat, strict=True))
def _log_foxglove_image(
    topic: str,
    frame_id: str,
    arr: np.ndarray,
    *,
    compress_images: bool,
    channels: dict | None = None,
    log_time: int | None = None,
) -> None:
    """Log an image on a cached per-topic Foxglove channel.

    ``arr`` may be a 2D grayscale array, HWC, or CHW (CHW is transposed to HWC). Floating-point
    images are assumed normalized to [0, 1] and scaled to uint8; every other dtype is cast to uint8.
    With ``compress_images`` set, grayscale (1ch) and color (3ch) frames are JPEG-encoded
    (3-channel input is treated as RGB), while 4-channel (RGBA) frames are always sent raw.

    Frames that cannot be represented are skipped with a warning instead of raising or publishing a
    malformed message: arrays that are not 2D/3D, unsupported channel counts, and failed JPEG encodes.

    Args:
        topic: Foxglove topic to publish on.
        frame_id: Frame identifier written into the image message.
        arr: Image data.
        compress_images: Whether to JPEG-encode 1- and 3-channel frames.
        channels: Per-topic channel cache to reuse; defaults to the cache stored on
            ``log_foxglove_data`` for live streaming.
        log_time: Message time in nanoseconds. When ``None`` the current wall-clock time is used for
            the message header timestamp and no explicit log time is passed to ``channel.log``.
    """
    from foxglove.channels import CompressedImageChannel, RawImageChannel
    from foxglove.messages import CompressedImage, RawImage, Timestamp

    if channels is None:
        channels = log_foxglove_data.channels

    # Only 2D and 3D arrays have a meaningful height/width/channel layout. Without this guard a 1D
    # array raises IndexError below and a 4D+ array yields width/height/step values that do not
    # match the byte payload.
    if arr.ndim not in (2, 3):
        logging.warning(
            "Foxglove: skipping image on topic '%s' with unsupported shape %s; "
            "expected a 2D (H, W) or 3D (H, W, C) / (C, H, W) array.",
            topic,
            tuple(arr.shape),
        )
        return

    time_ns = time.time_ns() if log_time is None else log_time
    timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000)
    log_kwargs = {} if log_time is None else {"log_time": log_time}

    # Convert CHW -> HWC when needed (mirrors log_rerun_data).
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        arr = np.transpose(arr, (1, 2, 0))
    if np.issubdtype(arr.dtype, np.floating):
        arr = (arr * 255.0).clip(0, 255)
    arr = np.ascontiguousarray(arr, dtype=np.uint8)

    height, width = arr.shape[0], arr.shape[1]
    n_channels = 1 if arr.ndim == 2 else arr.shape[2]

    if compress_images and n_channels in (1, 3):
        # OpenCV encodes in BGR order, so swap channels for color frames first.
        buf_src = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) if n_channels == 3 else arr
        ok, buf = cv2.imencode(".jpg", buf_src)
        if not ok:
            # Never publish a buffer the encoder reported as invalid.
            logging.warning(
                "Foxglove: JPEG encoding failed for image on topic '%s'; skipping frame.", topic
            )
            return
        channel = channels.get(topic)
        if channel is None:
            channel = channels[topic] = CompressedImageChannel(topic=topic)
        channel.log(
            CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg"),
            **log_kwargs,
        )
        return

    encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(n_channels)
    if encoding is None:
        logging.warning(
            "Foxglove: skipping image on topic '%s' with unsupported shape %s (%d channels); "
            "expected 1 (mono8), 3 (rgb8), or 4 (rgba8) channels.",
            topic,
            tuple(arr.shape),
            n_channels,
        )
        return

    channel = channels.get(topic)
    if channel is None:
        channel = channels[topic] = RawImageChannel(topic=topic)
    channel.log(
        RawImage(
            timestamp=timestamp,
            frame_id=frame_id,
            width=width,
            height=height,
            encoding=encoding,
            # One row is ``width`` pixels of ``n_channels`` bytes each (data is uint8).
            step=width * n_channels,
            data=arr.tobytes(),
        ),
        **log_kwargs,
    )
def log_rerun_data(
    observation: RobotObservation | None = None,
    action: RobotAction | None = None,
    compress_images: bool = False,
) -> None:
    """Send one step of observation and action data to Rerun.

    Entries whose value is ``None`` are skipped. Keys that do not already start with the
    observation/action prefix are namespaced with it. Values are logged as follows:

    - Scalar values (see ``_is_scalar``) become a single ``rr.Scalars``.
    - 1D observation arrays become one ``rr.Scalars`` batch under a single entity path, so all
      dimensions share one view.
    - Other observation arrays are treated as images: 3D channel-first arrays (1, 3 or 4 channels)
      are transposed to HWC; arrays whose last dimension is 1 are logged as ``rr.DepthImage``, the
      rest as ``rr.Image`` (compressed when ``compress_images`` is set). Images are logged with
      ``static=True``.
    - Action arrays of any rank are flattened into one ``rr.Scalars`` batch.

    After logging, the collected entity paths are handed to ``_ensure_blueprint``, which sends the
    layout on the first call that logged anything.

    Args:
        observation: An optional dictionary containing observation data to log.
        action: An optional dictionary containing action data to log.
        compress_images: Whether to compress images before logging to save bandwidth & memory in
            exchange for cpu and quality.
    """
    require_package("rerun-sdk", extra="viz", import_name="rerun")
    import rerun as rr

    obs_series: set[str] = set()
    act_series: set[str] = set()
    cameras: set[str] = set()

    for name, value in (observation or {}).items():
        if value is None:
            continue
        path = name if str(name).startswith(OBS_PREFIX) else f"{OBS_STR}.{name}"

        if _is_scalar(value):
            rr.log(path, rr.Scalars(float(value)))
            obs_series.add(path)
            continue
        if not isinstance(value, np.ndarray):
            continue

        frame = value
        # Channel-first images are moved to channel-last before logging.
        channels_first = frame.ndim == 3 and frame.shape[0] in (1, 3, 4) and frame.shape[-1] not in (1, 3, 4)
        if channels_first:
            frame = np.transpose(frame, (1, 2, 0))

        if frame.ndim == 1:
            rr.log(path, rr.Scalars(frame.astype(float)))
            obs_series.add(path)
            continue

        if frame.shape[-1] == 1:
            entity = rr.DepthImage(frame, colormap=rr.components.Colormap.Viridis)
        elif compress_images:
            entity = rr.Image(frame).compress()
        else:
            entity = rr.Image(frame)
        rr.log(path, entity=entity, static=True)
        cameras.add(path)

    for name, value in (action or {}).items():
        if value is None:
            continue
        path = name if str(name).startswith(ACTION_PREFIX) else f"{ACTION}.{name}"

        if _is_scalar(value):
            rr.log(path, rr.Scalars(float(value)))
        elif isinstance(value, np.ndarray):
            # Any array, including higher-dimensional ones, collapses into one batched Scalars.
            rr.log(path, rr.Scalars(value.reshape(-1).astype(float)))
        else:
            continue
        act_series.add(path)

    _ensure_blueprint(obs_series, act_series, cameras)
def log_foxglove_data(
    observation: RobotObservation | None = None,
    action: RobotAction | None = None,
    compress_images: bool = False,
) -> None:
    """
    Logs observation and action data to a Foxglove WebSocket server for real-time visualization.

    Mirrors :func:`log_rerun_data` but emits Foxglove messages over the server started by
    :func:`init_foxglove`. Data is mapped as follows:

    - Scalars (and elements of 1D arrays) are accumulated per source and logged on the
      ``/observation/state`` and ``/action/state`` topics as typed JSON messages using the static
      ``lerobot.Scalars`` schema: a ``scalars`` array of ``{label, value}`` objects (see
      :data:`_SCALARS_SCHEMA`). The ``label`` field lets Foxglove name each series automatically, so
      ``/observation/state.scalars[:].value`` plots every feature at once.
    - Observation arrays that are not 1D are handed to :func:`_log_foxglove_image` and logged on a
      per-source topic (e.g. ``/observation/images/front``) as a ``RawImage`` (or a JPEG
      ``CompressedImage`` when ``compress_images`` is True).
    - Action arrays of any rank are flattened and logged as scalars.

    Entries whose value is ``None`` are skipped. All messages of one call share the same log time.

    Args:
        observation: An optional dictionary containing observation data to log.
        action: An optional dictionary containing action data to log.
        compress_images: Whether to JPEG-compress images before logging to save bandwidth in exchange
            for CPU and quality.

    Raises:
        RuntimeError: If :func:`init_foxglove` has not been called first.
    """
    require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
    if getattr(log_foxglove_data, "server", None) is None:
        raise RuntimeError("init_foxglove() must be called before log_foxglove_data().")

    # One timestamp per call so images and scalars of the same step line up in Foxglove.
    now = time.time_ns()

    if observation:
        obs_scalars: dict[str, float] = {}
        for k, v in observation.items():
            if v is None:
                continue
            # Normalise to ``str`` once, so the prefix test and the strip act on the same value.
            key = str(k).removeprefix(OBS_PREFIX)
            if _is_scalar(v):
                obs_scalars[key] = float(v)
            elif isinstance(v, np.ndarray):
                if v.ndim == 1:
                    obs_scalars.update(_labeled_scalars(key, v))
                else:
                    _log_foxglove_image(
                        _foxglove_topic(k, is_image=True),
                        key,
                        v,
                        compress_images=compress_images,
                        log_time=now,
                    )
        _log_foxglove_scalars(_foxglove_topic(OBS_STATE), obs_scalars, log_time=now)

    if action:
        action_scalars: dict[str, float] = {}
        for k, v in action.items():
            if v is None:
                continue
            key = str(k).removeprefix(ACTION_PREFIX)
            if _is_scalar(v):
                action_scalars[key] = float(v)
            elif isinstance(v, np.ndarray):
                action_scalars.update(_labeled_scalars(key, v.flatten()))
        _log_foxglove_scalars(_foxglove_topic(ACTION), action_scalars, log_time=now)
# ── Dataset playback over a Foxglove WebSocket server ─────────────────────
# A LeRobotDataset is random-access on disk, so rather than fire-and-forget a forward stream we
# advertise a seekable timeline and serve frames on demand for whatever time the user scrubs/plays
# to in the Foxglove app. This relies on the SDK's PlaybackControl capability.
def _feature_dim_names(feature: dict | None) -> list[str] | None:
"""Best-effort per-dimension series labels for a 1D feature, or ``None`` to fall back to indices.
LeRobot records a feature's ``names`` inconsistently: a flat list (``["x", "y"]``), a category
mapping (``{"motors": ["motor_0", "motor_1"]}``), or a name->index mapping
(``{"delta_x": 0, "delta_y": 1}``). Each is handled, but labels are only returned when their count
matches the feature's 1D shape, so a malformed/mismatched ``names`` can't silently mislabel series.
"""
if not feature:
return None
shape = feature.get("shape")
dim = shape[0] if shape and len(shape) == 1 else None
names = feature.get("names")
labels: list[str] | None = None
if isinstance(names, dict):
values = list(names.values())
if values and all(isinstance(v, (list, tuple)) for v in values):
labels = [str(n) for group in values for n in group]
elif values and all(isinstance(v, int) and not isinstance(v, bool) for v in values):
labels = [name for name, _ in sorted(names.items(), key=lambda kv: kv[1])]
elif isinstance(names, (list, tuple)):
labels = [str(n) for n in names]
if labels is not None and dim is not None and len(labels) == dim:
return labels
return None
def _frame_to_scalars(sample: dict, key: str, labels: list[str] | None = None) -> dict[str, float]:
    """Turn the feature ``key`` of a dataset frame into ``{label: value}`` entries.

    The value is converted to a NumPy array (via its ``numpy()`` method when it has one). A 0-d
    value yields a single entry named after the feature with its observation/action prefix removed.
    Anything else is flattened and labelled through ``_labeled_scalars``, which uses ``labels`` when
    their count matches and ``{name}_{i}`` otherwise. A missing or ``None`` feature gives ``{}``.
    """
    raw = sample.get(key)
    if raw is None:
        return {}

    data = raw.numpy() if hasattr(raw, "numpy") else np.asarray(raw)

    # Short feature name: the key without its leading observation/action prefix, if any.
    short = key
    for prefix in (OBS_PREFIX, ACTION_PREFIX):
        if key.startswith(prefix):
            short = key[len(prefix) :]
            break

    if data.ndim == 0:
        return {short: float(data)}
    return _labeled_scalars(short, data.flatten(), labels)
def serve_foxglove_dataset_playback(
    dataset,
    episode_index: int,
    *,
    host: str = "127.0.0.1",
    port: int = 8765,
    compress_images: bool = False,
) -> None:
    """Serve a single dataset episode to Foxglove as a seekable, scrubbable timeline.

    Starts a Foxglove WebSocket server advertising the ``PlaybackControl`` capability over the
    episode's time range. The Foxglove app drives play/pause/seek/speed; a background thread and a
    ``ServerListener`` read frames from the on-disk ``dataset`` on demand and log them stamped at
    their dataset timestamps, so the user can scrub anywhere in the episode. Blocks until interrupted.

    Threading model: the listener callback only mutates the shared ``state`` dict under ``lock``;
    the background ``playback_loop`` thread is the only one that reads frames and emits them (apart
    from the single initial frame emitted before that thread starts).

    Args:
        dataset: A ``LeRobotDataset`` loaded for the single episode to visualize.
        episode_index: Index of the episode being visualized (used only for the session name).
        host: Host interface to bind the WebSocket server to.
        port: Port to bind the WebSocket server to.
        compress_images: Whether to JPEG-compress camera frames before logging.

    Raises:
        ValueError: If the dataset contains no frames.
    """
    require_package("foxglove-sdk", extra="foxglove", import_name="foxglove")
    import bisect
    import threading

    import foxglove
    from foxglove.websocket import (
        Capability,
        PlaybackCommand,
        PlaybackControlRequest,
        PlaybackState,
        PlaybackStatus,
        ServerListener,
    )

    # Per-frame timestamps in nanoseconds (read straight from the table, no video decode).
    times_ns = [int(round(float(t) * 1e9)) for t in dataset.hf_dataset["timestamp"]]
    n_frames = len(times_ns)
    if n_frames == 0:
        raise ValueError("Cannot visualize an empty episode.")
    # NOTE(review): the first/last entries are used as the timeline bounds and `index_at` bisects
    # this list, so timestamps are presumably non-decreasing -- confirm against the dataset format.
    first_ns, last_ns = times_ns[0], times_ns[-1]
    camera_keys = list(dataset.meta.camera_keys)

    # Per-dimension series labels from the dataset metadata (e.g. joint names), computed once.
    scalar_labels = {
        OBS_STATE: _feature_dim_names(dataset.meta.features.get(OBS_STATE)),
        ACTION: _feature_dim_names(dataset.meta.features.get(ACTION)),
    }

    # Local channel cache so the playback server is self-contained and doesn't touch the module global.
    channels: dict = {}

    def emit_frame(i: int) -> None:
        """Log every channel for frame ``i`` stamped at its dataset timestamp."""
        sample = dataset[i]
        log_time = times_ns[i]
        # Camera images: one topic per camera key; cameras absent from the sample are skipped.
        for key in camera_keys:
            arr = sample.get(key)
            if arr is None:
                continue
            arr = arr.numpy() if hasattr(arr, "numpy") else np.asarray(arr)
            _log_foxglove_image(
                _foxglove_topic(key, is_image=True),
                key,
                arr,
                compress_images=compress_images,
                channels=channels,
                log_time=log_time,
            )
        # Observation state and action vectors, labelled from the dataset metadata when available.
        _log_foxglove_scalars(
            _foxglove_topic(OBS_STATE),
            _frame_to_scalars(sample, OBS_STATE, scalar_labels[OBS_STATE]),
            channels=channels,
            log_time=log_time,
        )
        _log_foxglove_scalars(
            _foxglove_topic(ACTION),
            _frame_to_scalars(sample, ACTION, scalar_labels[ACTION]),
            channels=channels,
            log_time=log_time,
        )
        # Optional per-step episode signals, grouped on a single topic; only those present are sent.
        episode_scalars = {}
        for feat, label in (
            (DONE, "done"),
            (TRUNCATED, "truncated"),
            (REWARD, "reward"),
            (SUCCESS, "success"),
        ):
            v = sample.get(feat)
            if v is not None:
                episode_scalars[label] = float(v)
        _log_foxglove_scalars("/episode/state", episode_scalars, channels=channels, log_time=log_time)

    lock = threading.Lock()
    stop_event = threading.Event()
    # Shared playback state, guarded by ``lock``. ``seek_idx`` is a one-shot request set by the
    # listener and serviced by the playback loop, which is the *only* thread that emits frames (so
    # concurrent random access into the on-disk dataset / video decoder never overlaps).
    state = {
        "status": PlaybackStatus.Paused,
        "cursor": first_ns,  # current playback time in nanoseconds
        "speed": 1.0,  # playback speed multiplier
        "last_idx": -1,  # index of the last frame accounted for as emitted
        "seek_idx": None,  # pending seek target (frame index), or None
    }

    def index_at(t_ns: int) -> int:
        # Index of the last frame whose timestamp is <= t_ns, clamped to the valid frame range.
        return max(0, min(n_frames - 1, bisect.bisect_right(times_ns, t_ns) - 1))

    class _PlaybackListener(ServerListener):
        def on_playback_control_request(self, req: PlaybackControlRequest):
            # Only mutate state here; the playback loop performs all frame emission.
            with lock:
                did_seek = False
                if req.seek_time is not None:
                    # Clamp the requested time to the episode and queue the matching frame.
                    cursor = max(first_ns, min(last_ns, req.seek_time))
                    state["cursor"] = cursor
                    state["last_idx"] = state["seek_idx"] = index_at(cursor)
                    did_seek = True
                # Ignore missing, zero or negative speeds.
                if req.playback_speed and req.playback_speed > 0:
                    state["speed"] = req.playback_speed
                if req.playback_command == PlaybackCommand.Play:
                    # Restarting from the end replays from the beginning.
                    if state["cursor"] >= last_ns:
                        state["cursor"] = first_ns
                        state["last_idx"] = state["seek_idx"] = 0
                        did_seek = True
                    state["status"] = PlaybackStatus.Playing
                elif req.playback_command == PlaybackCommand.Pause:
                    state["status"] = PlaybackStatus.Paused
                # Snapshot the resulting state while still holding the lock.
                status, cursor, speed = state["status"], state["cursor"], state["speed"]
            request_id = req.request_id or ""
            return PlaybackState(status, cursor, speed, did_seek, request_id)

    server = foxglove.start_server(
        name=f"{dataset.repo_id}/episode_{episode_index}",
        host=host,
        port=port,
        capabilities=[Capability.PlaybackControl, Capability.Time],
        server_listener=_PlaybackListener(),
        playback_time_range=(first_ns, last_ns),
    )

    def playback_loop() -> None:
        # Cap how far the cursor may advance in a single tick. A slow frame decode (or any stall)
        # would otherwise make ``dt`` huge and produce one enormous catch-up batch; clamping it makes
        # playback trail wall-clock under a slow decoder while each tick emits a bounded frame range.
        max_tick_dt_s = 0.25
        prev = time.monotonic()
        while not stop_event.is_set():
            # Tick at roughly 60 Hz.
            time.sleep(1.0 / 60.0)
            ended = False
            speed = 1.0
            with lock:
                now = time.monotonic()
                dt = min(now - prev, max_tick_dt_s)
                prev = now
                # A queued seek is always serviced, even while paused, so scrubbing updates the view.
                work = []
                seek_idx = state["seek_idx"]
                if seek_idx is not None:
                    state["seek_idx"] = None
                    work.append(seek_idx)
                if state["status"] == PlaybackStatus.Playing:
                    # Advance the cursor by the scaled elapsed time and collect the frames it passed.
                    cursor = state["cursor"] + int(dt * 1e9 * state["speed"])
                    start_idx = state["last_idx"] + 1
                    if cursor >= last_ns:
                        cursor, target, ended = last_ns, n_frames - 1, True
                    else:
                        target = index_at(cursor)
                    state["cursor"] = cursor
                    work.extend(range(start_idx, target + 1))
                    # cursor only grows while playing (seeks reset last_idx in the listener), so
                    # target >= last_idx here; a plain assignment is correct and clearer than max().
                    state["last_idx"] = target
                    if ended:
                        state["status"] = PlaybackStatus.Ended
                if not work:
                    continue
                cursor, speed = state["cursor"], state["speed"]
            # Emit outside the lock; this is the only thread that calls emit_frame. Re-check
            # stop_event between frames so shutdown stays responsive even mid-batch.
            for i in work:
                if stop_event.is_set():
                    break
                emit_frame(i)
            server.broadcast_time(cursor)
            if ended:
                server.broadcast_playback_state(PlaybackState(PlaybackStatus.Ended, cursor, speed, False, ""))

    # Emit the first frame so channels are advertised (done before the loop starts, so emission stays
    # single-threaded). Late-connecting clients re-receive frames once they seek/play.
    # NOTE(review): the server is already running here but the try/finally below has not been
    # entered yet, so an exception from this initial emit would leave the server un-stopped.
    emit_frame(0)
    with lock:
        state["last_idx"] = 0
    server.broadcast_time(first_ns)
    server.broadcast_playback_state(PlaybackState(PlaybackStatus.Paused, first_ns, 1.0, True, ""))

    thread = threading.Thread(target=playback_loop, name="foxglove-playback", daemon=True)
    thread.start()

    print(f"Foxglove server running. Connect the Foxglove app to ws://{host}:{port}")
    print("Use the playback controls in Foxglove to play/pause and scrub the episode. Ctrl-C to exit.")
    try:
        # Keep the main thread alive; the playback thread and the listener do the actual work.
        while not stop_event.is_set():
            time.sleep(0.5)
    except KeyboardInterrupt:
        print("Ctrl-C received. Exiting.")
    finally:
        # Signal the playback thread, give it a bounded time to finish, then tear the server down.
        stop_event.set()
        thread.join(timeout=2.0)
        server.stop()
        channels.clear()
# ── Backend-agnostic dispatch ─────────────────────────────────────────────
# These let callers select a visualization backend at runtime via a string
# (e.g. a `--display_mode` CLI flag) without branching on the backend everywhere.
def init_visualization(
display_mode: str,
+101
View File
@@ -0,0 +1,101 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Foxglove backend's pure helpers.
These cover topic naming, series labelling and feature-name parsing. They import
``foxglove_visualization`` directly and need NO ``foxglove`` extra: the SDK is imported lazily inside
the functions that talk to the server, so the helpers below run in the base test tier.
"""
import numpy as np
from lerobot.utils import foxglove_visualization as fv
from lerobot.utils.constants import ACTION, OBS_STATE
def test_foxglove_safe_name_collapses_dots():
    """Dots become underscores; names without dots are returned unchanged."""
    expectations = (
        ("observation.images.front", "observation_images_front"),
        ("plain", "plain"),
    )
    for raw, safe in expectations:
        assert fv._foxglove_safe_name(raw) == safe
def test_foxglove_topic_image_strips_prefix_without_doubling_images():
    """Image topics drop the observation/images prefix and flatten nested camera names."""
    cases = (
        # Fully-qualified camera key -> single clean segment (no doubled "images").
        ("observation.images.front", "/observation/images/front"),
        # A nested camera name keeps its structure via safe-name collapsing.
        ("observation.images.wrist.left", "/observation/images/wrist_left"),
        # Bare camera name (as real robots emit).
        ("front", "/observation/images/front"),
    )
    for camera_key, topic in cases:
        assert fv._foxglove_topic(camera_key, is_image=True) == topic
def test_foxglove_topic_scalar_sources():
    """Scalar keys map to the aggregate state topic of their source (observation or action)."""
    cases = (
        (OBS_STATE, "/observation/state"),
        ("observation.environment_state", "/observation/state"),
        (ACTION, "/action/state"),
        ("action.delta", "/action/state"),
    )
    for feature_key, topic in cases:
        assert fv._foxglove_topic(feature_key) == topic
def test_labeled_scalars_uses_labels_then_index_fallback():
    """Matching labels are used as-is; missing or wrong-length labels fall back to ``name_i``."""
    indexed = fv._labeled_scalars("state", np.array([1.0, 2.0, 3.0]))
    assert indexed == {"state_0": 1.0, "state_1": 2.0, "state_2": 3.0}

    named = fv._labeled_scalars("state", [1.0, 2.0], ["pan", "lift"])
    assert named == {"pan": 1.0, "lift": 2.0}

    # Wrong-length labels fall back to index naming (never silently mislabels).
    mismatched = fv._labeled_scalars("q", [1.0, 2.0], ["only_one"])
    assert mismatched == {"q_0": 1.0, "q_1": 2.0}
def test_frame_to_scalars_matches_live_labeling_and_handles_scalar():
    """Frame flattening agrees with the live-stream labelling and handles scalar/missing features."""
    frame = {OBS_STATE: np.array([1.0, 2.0])}

    # No metadata -> {short_name}_{i}, identical to the live-stream fallback.
    live_fallback = fv._labeled_scalars("state", np.array([1.0, 2.0]))
    assert fv._frame_to_scalars(frame, OBS_STATE) == live_fallback
    assert fv._frame_to_scalars(frame, OBS_STATE) == {"state_0": 1.0, "state_1": 2.0}

    # Metadata labels are honored.
    labelled = fv._frame_to_scalars(frame, OBS_STATE, ["pan", "lift"])
    assert labelled == {"pan": 1.0, "lift": 2.0}

    # A 0-d scalar becomes a single entry named by the short feature name.
    scalar_frame = {ACTION: np.array(5.0)}
    assert fv._frame_to_scalars(scalar_frame, ACTION) == {"action": 5.0}

    # A missing feature yields an empty mapping.
    assert fv._frame_to_scalars({}, OBS_STATE) == {}
def test_feature_dim_names_formats():
    """Every supported ``names`` layout is parsed; invalid or mismatched metadata yields ``None``."""
    parsed_cases = (
        # Flat list of names.
        ({"shape": [2], "names": ["x", "y"]}, ["x", "y"]),
        # Category mapping (dict of lists).
        ({"shape": [2], "names": {"motors": ["m0", "m1"]}}, ["m0", "m1"]),
        # name -> index mapping (returned sorted by index).
        ({"shape": [2], "names": {"delta_x": 0, "delta_y": 1}}, ["delta_x", "delta_y"]),
    )
    for feature, labels in parsed_cases:
        assert fv._feature_dim_names(feature) == labels

    rejected_cases = (
        # Bool values must NOT be treated as an index map (bool is a subclass of int).
        {"shape": [2], "names": {"a": True, "b": False}},
        # Mismatched length -> None (won't silently mislabel).
        {"shape": [3], "names": ["x", "y"]},
        # Missing / absent names -> None.
        None,
        {"shape": [2]},
    )
    for feature in rejected_cases:
        assert fv._feature_dim_names(feature) is None
def test_is_scalar():
    """Real numbers and 0-d arrays are scalars; 1D arrays and strings are not."""
    scalars = (1.0, np.float32(2.0), np.array(3.0))  # the last one is a 0-d array
    for value in scalars:
        assert fv._is_scalar(value)

    non_scalars = (np.array([1.0, 2.0]), "x")
    for value in non_scalars:
        assert not fv._is_scalar(value)
+310
View File
@@ -0,0 +1,310 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import sys
from types import SimpleNamespace
import numpy as np
import pytest
pytest.importorskip("rerun", reason="rerun-sdk is required (install lerobot[viz])")
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_STATE
@pytest.fixture
def mock_rerun(monkeypatch):
    """
    Provide a mock `rerun` module (and `rerun.blueprint` submodule) so tests don't
    depend on the real library. Also reload the module-under-test so it binds to
    this mock `rr`.

    Yields a tuple ``(rv, calls, blueprints)``:
    - ``rv``: the reloaded ``lerobot.utils.rerun_visualization`` module.
    - ``calls``: list of ``(key, obj, kwargs)`` tuples, one per ``rr.log`` call.
    - ``blueprints``: list of the objects passed to ``rr.send_blueprint``.
    """
    calls = []
    blueprints = []

    class DummyScalar:
        def __init__(self, value):
            # Scalars may be built from a single float or from a 1D array batch.
            self.value = value

    class DummyImage:
        def __init__(self, arr):
            self.arr = arr

        def compress(self, *a, **k):
            # Compression is a no-op in the mock; the same object is returned.
            return self

    class DummyDepthImage:
        def __init__(self, arr, colormap=None):
            self.arr = arr
            self.colormap = colormap

    def dummy_log(key, obj=None, **kwargs):
        # Accept either positional `obj` or keyword `entity` and record remaining kwargs.
        if obj is None and "entity" in kwargs:
            obj = kwargs.pop("entity")
        calls.append((key, obj, kwargs))

    def dummy_send_blueprint(blueprint, *a, **k):
        blueprints.append(blueprint)

    # Mock the `rerun.blueprint` submodule used to build the layout.
    # Each factory returns a plain namespace tagged with ``kind`` so tests can inspect the layout.
    dummy_rrb = SimpleNamespace(
        Spatial2DView=lambda origin=None, name=None: SimpleNamespace(
            kind="Spatial2DView", origin=origin, name=name
        ),
        TimeSeriesView=lambda name=None, contents=None: SimpleNamespace(
            kind="TimeSeriesView", name=name, contents=contents
        ),
        Grid=lambda *views: SimpleNamespace(kind="Grid", views=list(views)),
        Blueprint=lambda root: SimpleNamespace(kind="Blueprint", root=root),
    )

    # Stand-in for the top-level `rerun` module, carrying module-like dunder attributes plus the
    # API surface provided to the code under test.
    dummy_rr = SimpleNamespace(
        __name__="rerun",
        __package__="rerun",
        __spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
        Scalars=DummyScalar,
        Image=DummyImage,
        DepthImage=DummyDepthImage,
        components=SimpleNamespace(Colormap=SimpleNamespace(Viridis="viridis")),
        log=dummy_log,
        send_blueprint=dummy_send_blueprint,
        init=lambda *a, **k: None,
        spawn=lambda *a, **k: None,
        blueprint=dummy_rrb,
    )

    # Inject fake modules into sys.modules (both `rerun` and `rerun.blueprint`).
    monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
    monkeypatch.setitem(sys.modules, "rerun.blueprint", dummy_rrb)

    # Now import and reload the module under test, to bind to our rerun mock
    import lerobot.utils.rerun_visualization as rv

    importlib.reload(rv)

    # Expose the reloaded module, the call recorder and the captured blueprints
    yield rv, calls, blueprints
def _keys(calls):
"""Helper to extract just the keys logged to rr.log"""
return [k for (k, _obj, _kw) in calls]
def _obj_for(calls, key):
"""Find the first object logged under a given key."""
for k, obj, _kw in calls:
if k == key:
return obj
raise KeyError(f"Key {key} not found in calls: {calls}")
def _kwargs_for(calls, key):
for k, _obj, kw in calls:
if k == key:
return kw
raise KeyError(f"Key {key} not found in calls: {calls}")
def _views_by_kind(blueprint, kind):
"""Return the views of a given kind from the (single) blueprint's grid."""
return [v for v in blueprint.root.views if v.kind == kind]
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
    """Log observation/action dicts pulled from an EnvTransition and check entities and layout."""
    module, logged, sent = mock_rerun
    temperature_key = f"{OBS_STATE}.temperature"

    transition = {
        TransitionKey.OBSERVATION: {
            temperature_key: np.float32(25.0),
            # Channel-first (CHW) input; expected to reach the image entity as HWC.
            "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
        },
        TransitionKey.ACTION: {
            "action.throttle": 0.7,
            # A 1D array is expected to land under one entity path as a single batch.
            "action.vector": np.array([1.0, 2.0], dtype=np.float32),
        },
    }

    # Pull the two dicts out of the transition the same way the real call sites do.
    module.log_rerun_data(
        observation=transition.get(TransitionKey.OBSERVATION, {}),
        action=transition.get(TransitionKey.ACTION, {}),
    )

    # Exactly one entity per input key, with no per-element suffix for the vector.
    assert set(_keys(logged)) == {
        temperature_key,
        "observation.camera",
        "action.throttle",
        "action.vector",
    }

    # Single-valued scalars keep their type and value.
    for path, expected in ((temperature_key, 25.0), ("action.throttle", 0.7)):
        scalar = _obj_for(logged, path)
        assert type(scalar).__name__ == "DummyScalar"
        assert float(scalar.value) == pytest.approx(expected)

    # The vector is one batched scalar object.
    batched = _obj_for(logged, "action.vector")
    assert type(batched).__name__ == "DummyScalar"
    np.testing.assert_allclose(np.asarray(batched.value), [1.0, 2.0])

    # The image was transposed CHW -> HWC and logged with static=True.
    camera = _obj_for(logged, "observation.camera")
    assert type(camera).__name__ == "DummyImage"
    assert camera.arr.shape == (10, 20, 3)
    assert _kwargs_for(logged, "observation.camera").get("static", False) is True

    # One blueprint was sent, and the same object is cached on the function.
    assert len(sent) == 1
    layout = sent[0]
    assert module.log_rerun_data.blueprint is layout

    # One spatial view per image path.
    image_origins = {view.origin for view in _views_by_kind(layout, "Spatial2DView")}
    assert image_origins == {"observation.camera"}

    # One time-series view for observation scalars and one for action scalars.
    series = {view.name: view for view in _views_by_kind(layout, "TimeSeriesView")}
    assert set(series) == {"observation", "action"}
    assert series["observation"].contents == [temperature_key]
    assert series["action"].contents == ["action.throttle", "action.vector"]
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
    """Un-prefixed keys are expected under `observation.` / `action.`; `None` values are dropped."""
    module, logged, sent = mock_rerun

    module.log_rerun_data(
        observation={
            "temp": 1.5,
            # Already channel-last (HWC), so the array shape must be left untouched.
            "img": np.zeros((5, 6, 3), dtype=np.uint8),
            # A None value must not produce any logged entity.
            "none": None,
        },
        action={
            "throttle": 0.3,
            "vec": np.array([9, 8, 7], dtype=np.float32),
        },
    )

    # Every key gained its prefix; the 1D vector stays a single entity.
    assert set(_keys(logged)) == {
        "observation.temp",
        "observation.img",
        "action.throttle",
        "action.vec",
    }

    # Single-valued scalars.
    for path, expected in (("observation.temp", 1.5), ("action.throttle", 0.3)):
        scalar = _obj_for(logged, path)
        assert type(scalar).__name__ == "DummyScalar"
        assert float(scalar.value) == pytest.approx(expected)

    # The HWC image is passed through unchanged and logged with static=True.
    picture = _obj_for(logged, "observation.img")
    assert type(picture).__name__ == "DummyImage"
    assert picture.arr.shape == (5, 6, 3)
    assert _kwargs_for(logged, "observation.img").get("static", False) is True

    # The vector is one batched scalar object.
    batched = _obj_for(logged, "action.vec")
    assert type(batched).__name__ == "DummyScalar"
    np.testing.assert_allclose(np.asarray(batched.value), [9, 8, 7])

    # A single blueprint with the expected views.
    assert len(sent) == 1
    layout = sent[0]
    image_origins = {view.origin for view in _views_by_kind(layout, "Spatial2DView")}
    assert image_origins == {"observation.img"}
    series = {view.name: view for view in _views_by_kind(layout, "TimeSeriesView")}
    assert series["observation"].contents == ["observation.temp"]
    assert series["action"].contents == ["action.throttle", "action.vec"]
def test_log_rerun_data_kwargs_only(mock_rerun):
    """Already-prefixed keys passed by keyword are logged as-is; a 1-channel image is a depth image."""
    module, logged, sent = mock_rerun

    observation = {
        "observation.temp": 10.0,
        "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8),
    }
    module.log_rerun_data(observation=observation, action={"action.a": 1.0})

    seen = set(_keys(logged))
    for path in ("observation.temp", "observation.gray", "action.a"):
        assert path in seen

    # Single-valued scalars.
    for path, expected in (("observation.temp", 10.0), ("action.a", 1.0)):
        scalar = _obj_for(logged, path)
        assert type(scalar).__name__ == "DummyScalar"
        assert float(scalar.value) == pytest.approx(expected)

    # single-channel -> DepthImage, shape left as HWC, logged with static=True
    gray = _obj_for(logged, "observation.gray")
    assert type(gray).__name__ == "DummyDepthImage"
    assert gray.arr.shape == (8, 8, 1)
    assert _kwargs_for(logged, "observation.gray").get("static", False) is True

    # One blueprint: a spatial view for the image plus time-series views for the scalars.
    assert len(sent) == 1
    layout = sent[0]
    image_origins = {view.origin for view in _views_by_kind(layout, "Spatial2DView")}
    assert image_origins == {"observation.gray"}
    series = {view.name: view for view in _views_by_kind(layout, "TimeSeriesView")}
    assert series["observation"].contents == ["observation.temp"]
    assert series["action"].contents == ["action.a"]
def test_log_rerun_data_blueprint_sent_only_once(mock_rerun):
    """A second logging call must neither send another blueprint nor replace the cached one."""
    module, _logged, sent = mock_rerun

    module.log_rerun_data(observation={"temp": 1.0}, action={"a": 2.0})
    assert len(sent) == 1
    cached = module.log_rerun_data.blueprint

    module.log_rerun_data(observation={"temp": 3.0}, action={"a": 4.0})

    # The count is unchanged and the function still holds the very same object.
    assert len(sent) == 1
    assert module.log_rerun_data.blueprint is cached
+13 -287
View File
@@ -14,297 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import sys
from types import SimpleNamespace
"""Tests for the backend-agnostic visualization dispatch.
These exercise the display-mode routing/validation only; they need neither ``rerun`` nor
``foxglove`` installed since the unknown-mode branch raises before touching any backend. Backend
behavior is covered in ``test_rerun_visualization.py`` and ``test_foxglove_visualization.py``.
"""
import numpy as np
import pytest
pytest.importorskip("rerun", reason="rerun-sdk is required (install lerobot[viz])")
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_STATE
from lerobot.utils import visualization_utils as vu
@pytest.fixture
def mock_rerun(monkeypatch):
    """
    Swap in stand-ins for `rerun` and `rerun.blueprint`, then reload the module under test.

    The stand-ins are registered in `sys.modules` through `monkeypatch`, and
    `lerobot.utils.visualization_utils` is reloaded afterwards so that it binds to them
    instead of the real library.

    Yields:
        A `(module, calls, blueprints)` tuple: the reloaded module, the list of
        `(key, obj, kwargs)` tuples recorded for each `log` call, and the list of
        objects handed to `send_blueprint`.
    """
    calls, blueprints = [], []

    class DummyScalar:
        def __init__(self, value):
            # Holds whatever was passed in: a single float or a 1D array batch.
            self.value = value

    class DummyImage:
        def __init__(self, arr):
            self.arr = arr

        def compress(self, *a, **k):
            # Compression is a no-op here; hand back the same object.
            return self

    class DummyDepthImage:
        def __init__(self, arr, colormap=None):
            self.arr = arr
            self.colormap = colormap

    def dummy_log(key, obj=None, **kwargs):
        # The logged object may arrive positionally or as the `entity` keyword; either way
        # it is recorded in the second slot and the leftover kwargs in the third.
        use_entity_kwarg = obj is None and "entity" in kwargs
        logged = kwargs.pop("entity") if use_entity_kwarg else obj
        calls.append((key, logged, kwargs))

    def dummy_send_blueprint(blueprint, *a, **k):
        blueprints.append(blueprint)

    # Builders for the `rerun.blueprint` layout objects; each one just records its arguments.
    def _spatial_2d_view(origin=None, name=None):
        return SimpleNamespace(kind="Spatial2DView", origin=origin, name=name)

    def _time_series_view(name=None, contents=None):
        return SimpleNamespace(kind="TimeSeriesView", name=name, contents=contents)

    def _grid(*views):
        return SimpleNamespace(kind="Grid", views=list(views))

    def _blueprint(root):
        return SimpleNamespace(kind="Blueprint", root=root)

    def _ignore_init(*a, **k):
        return None

    def _ignore_spawn(*a, **k):
        return None

    dummy_rrb = SimpleNamespace(
        Spatial2DView=_spatial_2d_view,
        TimeSeriesView=_time_series_view,
        Grid=_grid,
        Blueprint=_blueprint,
    )

    viridis_only = SimpleNamespace(Colormap=SimpleNamespace(Viridis="viridis"))
    dummy_rr = SimpleNamespace(
        __name__="rerun",
        __package__="rerun",
        __spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
        Scalars=DummyScalar,
        Image=DummyImage,
        DepthImage=DummyDepthImage,
        components=viridis_only,
        log=dummy_log,
        send_blueprint=dummy_send_blueprint,
        init=_ignore_init,
        spawn=_ignore_spawn,
        blueprint=dummy_rrb,
    )

    # Register both the package and its submodule so either import form resolves to the stand-ins.
    for module_name, stand_in in (("rerun", dummy_rr), ("rerun.blueprint", dummy_rrb)):
        monkeypatch.setitem(sys.modules, module_name, stand_in)

    # Import (if needed) and reload the module under test so it rebinds to the stand-ins.
    vu = importlib.import_module("lerobot.utils.visualization_utils")
    importlib.reload(vu)

    yield vu, calls, blueprints
def test_visualization_modes():
    """The advertised display modes are exactly `rerun` then `foxglove`."""
    expected_modes = ("rerun", "foxglove")
    assert vu.VISUALIZATION_MODES == expected_modes
def _keys(calls):
"""Helper to extract just the keys logged to rr.log"""
return [k for (k, _obj, _kw) in calls]
def _obj_for(calls, key):
"""Find the first object logged under a given key."""
for k, obj, _kw in calls:
if k == key:
return obj
raise KeyError(f"Key {key} not found in calls: {calls}")
def _kwargs_for(calls, key):
for k, _obj, kw in calls:
if k == key:
return kw
raise KeyError(f"Key {key} not found in calls: {calls}")
def _views_by_kind(blueprint, kind):
"""Return the views of a given kind from the (single) blueprint's grid."""
return [v for v in blueprint.root.views if v.kind == kind]
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
    """Log observation/action dicts pulled from an EnvTransition and check entities and layout."""
    module, logged, sent = mock_rerun
    temperature_key = f"{OBS_STATE}.temperature"

    transition = {
        TransitionKey.OBSERVATION: {
            temperature_key: np.float32(25.0),
            # Channel-first (CHW) input; expected to reach the image entity as HWC.
            "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
        },
        TransitionKey.ACTION: {
            "action.throttle": 0.7,
            # A 1D array is expected to land under one entity path as a single batch.
            "action.vector": np.array([1.0, 2.0], dtype=np.float32),
        },
    }

    # Pull the two dicts out of the transition the same way the real call sites do.
    module.log_rerun_data(
        observation=transition.get(TransitionKey.OBSERVATION, {}),
        action=transition.get(TransitionKey.ACTION, {}),
    )

    # Exactly one entity per input key, with no per-element suffix for the vector.
    assert set(_keys(logged)) == {
        temperature_key,
        "observation.camera",
        "action.throttle",
        "action.vector",
    }

    # Single-valued scalars keep their type and value.
    for path, expected in ((temperature_key, 25.0), ("action.throttle", 0.7)):
        scalar = _obj_for(logged, path)
        assert type(scalar).__name__ == "DummyScalar"
        assert float(scalar.value) == pytest.approx(expected)

    # The vector is one batched scalar object.
    batched = _obj_for(logged, "action.vector")
    assert type(batched).__name__ == "DummyScalar"
    np.testing.assert_allclose(np.asarray(batched.value), [1.0, 2.0])

    # The image was transposed CHW -> HWC and logged with static=True.
    camera = _obj_for(logged, "observation.camera")
    assert type(camera).__name__ == "DummyImage"
    assert camera.arr.shape == (10, 20, 3)
    assert _kwargs_for(logged, "observation.camera").get("static", False) is True

    # One blueprint was sent, and the same object is cached on the function.
    assert len(sent) == 1
    layout = sent[0]
    assert module.log_rerun_data.blueprint is layout

    # One spatial view per image path.
    image_origins = {view.origin for view in _views_by_kind(layout, "Spatial2DView")}
    assert image_origins == {"observation.camera"}

    # One time-series view for observation scalars and one for action scalars.
    series = {view.name: view for view in _views_by_kind(layout, "TimeSeriesView")}
    assert set(series) == {"observation", "action"}
    assert series["observation"].contents == [temperature_key]
    assert series["action"].contents == ["action.throttle", "action.vector"]
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
    """Un-prefixed keys are expected under `observation.` / `action.`; `None` values are dropped."""
    module, logged, sent = mock_rerun

    module.log_rerun_data(
        observation={
            "temp": 1.5,
            # Already channel-last (HWC), so the array shape must be left untouched.
            "img": np.zeros((5, 6, 3), dtype=np.uint8),
            # A None value must not produce any logged entity.
            "none": None,
        },
        action={
            "throttle": 0.3,
            "vec": np.array([9, 8, 7], dtype=np.float32),
        },
    )

    # Every key gained its prefix; the 1D vector stays a single entity.
    assert set(_keys(logged)) == {
        "observation.temp",
        "observation.img",
        "action.throttle",
        "action.vec",
    }

    # Single-valued scalars.
    for path, expected in (("observation.temp", 1.5), ("action.throttle", 0.3)):
        scalar = _obj_for(logged, path)
        assert type(scalar).__name__ == "DummyScalar"
        assert float(scalar.value) == pytest.approx(expected)

    # The HWC image is passed through unchanged and logged with static=True.
    picture = _obj_for(logged, "observation.img")
    assert type(picture).__name__ == "DummyImage"
    assert picture.arr.shape == (5, 6, 3)
    assert _kwargs_for(logged, "observation.img").get("static", False) is True

    # The vector is one batched scalar object.
    batched = _obj_for(logged, "action.vec")
    assert type(batched).__name__ == "DummyScalar"
    np.testing.assert_allclose(np.asarray(batched.value), [9, 8, 7])

    # A single blueprint with the expected views.
    assert len(sent) == 1
    layout = sent[0]
    image_origins = {view.origin for view in _views_by_kind(layout, "Spatial2DView")}
    assert image_origins == {"observation.img"}
    series = {view.name: view for view in _views_by_kind(layout, "TimeSeriesView")}
    assert series["observation"].contents == ["observation.temp"]
    assert series["action"].contents == ["action.throttle", "action.vec"]
def test_log_rerun_data_kwargs_only(mock_rerun):
    """Already-prefixed keys passed by keyword are logged as-is; a 1-channel image is a depth image."""
    module, logged, sent = mock_rerun

    observation = {
        "observation.temp": 10.0,
        "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8),
    }
    module.log_rerun_data(observation=observation, action={"action.a": 1.0})

    seen = set(_keys(logged))
    for path in ("observation.temp", "observation.gray", "action.a"):
        assert path in seen

    # Single-valued scalars.
    for path, expected in (("observation.temp", 10.0), ("action.a", 1.0)):
        scalar = _obj_for(logged, path)
        assert type(scalar).__name__ == "DummyScalar"
        assert float(scalar.value) == pytest.approx(expected)

    # single-channel -> DepthImage, shape left as HWC, logged with static=True
    gray = _obj_for(logged, "observation.gray")
    assert type(gray).__name__ == "DummyDepthImage"
    assert gray.arr.shape == (8, 8, 1)
    assert _kwargs_for(logged, "observation.gray").get("static", False) is True

    # One blueprint: a spatial view for the image plus time-series views for the scalars.
    assert len(sent) == 1
    layout = sent[0]
    image_origins = {view.origin for view in _views_by_kind(layout, "Spatial2DView")}
    assert image_origins == {"observation.gray"}
    series = {view.name: view for view in _views_by_kind(layout, "TimeSeriesView")}
    assert series["observation"].contents == ["observation.temp"]
    assert series["action"].contents == ["action.a"]
def test_log_rerun_data_blueprint_sent_only_once(mock_rerun):
    """A second logging call must neither send another blueprint nor replace the cached one."""
    module, _logged, sent = mock_rerun

    module.log_rerun_data(observation={"temp": 1.0}, action={"a": 2.0})
    assert len(sent) == 1
    cached = module.log_rerun_data.blueprint

    module.log_rerun_data(observation={"temp": 3.0}, action={"a": 4.0})

    # The count is unchanged and the function still holds the very same object.
    assert len(sent) == 1
    assert module.log_rerun_data.blueprint is cached
@pytest.mark.parametrize("func", ["init_visualization", "log_visualization_data", "shutdown_visualization"])
def test_dispatch_rejects_unknown_mode(func):
    """Each dispatch entry point raises `ValueError` when given an unrecognised display mode."""
    dispatcher = getattr(vu, func)
    with pytest.raises(ValueError, match="Unknown display_mode"):
        dispatcher("bogus")