Compare commits

7 Commits

Author SHA1 Message Date
CarolinePascal 6d8ef7dc60 fix(autocast): route inference autocasts through safe helper
Apply get_safe_autocast_context to the control_utils and sync inference
paths for uniformity with lerobot_eval. AMP is now enabled on any
AMP-capable device (cuda, xpu, cpu) when use_amp is set, and stays a
no-op on mps.
2026-07-03 13:22:30 +02:00
CarolinePascal ca6d764107 fix(autocast): gate autocast on AMP-capable devices
Add get_safe_autocast_context helper that only enters torch.autocast on
devices supporting AMP (cuda, xpu, cpu) and falls back to a no-op on mps
and unknown backends. Route the previously unconditional/underspecified
autocasts (vla_jepa, groot, molmoact2, lerobot_eval) through it so
autocast can be requested unconditionally without breaking on unsupported
devices.
2026-07-03 11:22:33 +02:00
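For readers skimming the log, here is a minimal sketch of the gating idea described in ca6d764107. The body of the real `get_safe_autocast_context` is not part of this diff, so the name `safe_autocast_sketch`, the `AMP_DEVICE_TYPES` set, and the string handling are assumptions made for illustration. Only the device list (cuda, xpu, cpu) and the call shapes (`device` or a device-type string, plus `dtype=` and `enabled=`) come from the commits above.

```python
from contextlib import nullcontext

# Device types treated as AMP-capable, as listed in the commit message.
AMP_DEVICE_TYPES = frozenset({"cuda", "xpu", "cpu"})


def safe_autocast_sketch(device, dtype=None, enabled=True):
    """Hypothetical stand-in: torch.autocast on AMP-capable devices, else a no-op."""
    # Accept a torch.device-like object (has .type) or a string such as "cuda:0".
    device_type = getattr(device, "type", None) or str(device).split(":")[0]
    if not enabled or device_type not in AMP_DEVICE_TYPES:
        return nullcontext()
    import torch  # imported lazily so the no-op path does not need torch

    return torch.autocast(device_type=device_type, dtype=dtype, enabled=True)
```

With a helper of this shape, call sites can request autocast unconditionally (`with safe_autocast_sketch(device, enabled=use_amp): ...`) and still get a plain no-op on `mps` or an unknown backend.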
Pepijn 07285677a3 fix(train): drive Accelerate mixed precision from policy.dtype (#3912)
* fix(train): drive Accelerate mixed precision from policy.dtype

`accelerator.autocast()` was always a no-op because `mixed_precision`
was never set, so `--policy.dtype=bfloat16` only cast the model params
(via the policy) while autocast-eligible ops still ran in fp32/tf32.

Map the active policy's `dtype` onto Accelerate's `mixed_precision`
(bfloat16 -> bf16, float16 -> fp16, float32 -> no) so autocast is active
for bf16/fp16 and stays full precision for float32. Policies without a
string `dtype` field fall back to Accelerate's launcher default, so
existing behavior is preserved.

* style(train): condense mixed-precision comment to one line
2026-07-02 19:15:19 +02:00
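The dtype mapping from 07285677a3 can be sketched as a small lookup. This is an illustration only: the function name `mixed_precision_from_policy` and the table name are hypothetical, and treating an unrecognized string the same as a missing `dtype` is an assumption. The three mapped values and the fallback for policies without a string `dtype` are taken from the commit message.

```python
from types import SimpleNamespace

# bfloat16 -> bf16, float16 -> fp16, float32 -> no (from the commit message).
_DTYPE_TO_MIXED_PRECISION = {"bfloat16": "bf16", "float16": "fp16", "float32": "no"}


def mixed_precision_from_policy(policy_cfg):
    """Return a mixed-precision mode, or None to defer to the launcher default."""
    dtype = getattr(policy_cfg, "dtype", None)
    if not isinstance(dtype, str):
        # No string dtype on this policy: keep the existing behavior.
        return None
    # Assumption: unrecognized strings also defer to the launcher default.
    return _DTYPE_TO_MIXED_PRECISION.get(dtype)


cfg = SimpleNamespace(dtype="bfloat16")
print(mixed_precision_from_policy(cfg))  # prints "bf16"
```

The returned value would presumably be handed to Accelerate's `mixed_precision` setting when the accelerator is built, which is what makes `accelerator.autocast()` active for bf16/fp16.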
Caroline Pascal 7ae12124b0 fix(save codec options): making sure codec options are always set via set_if (#3910)
* fix(save codec options): making sure codec options are always safely set through `set_if`

* tests(update): updating tests
2026-07-02 15:29:14 +02:00
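The definition of `set_if` is not shown in this diff, so the following is only a guess at its shape, written to illustrate why routing every option through one helper is safer than assigning to `opts[...]` directly. The assumptions are that it skips `None` and that it stores values as strings; `build_codec_options` is a hypothetical wrapper.

```python
def build_codec_options(crf=None, preset=None, threads=None, rc=None):
    """Hypothetical sketch: collect encoder options through a single guarded setter."""
    opts = {}

    def set_if(key, value):
        # Skip unset values; compare against None so that 0 is still written.
        if value is not None:
            opts[key] = str(value)

    set_if("crf", crf)
    set_if("preset", preset)
    set_if("threads", threads)
    set_if("rc", rc)
    return opts
```

Checking `is not None` rather than truthiness matters for calls such as `set_if("rc", 0)` in the diff below, where a falsy value is a legitimate setting.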
Caroline Pascal c746ca2df2 fix(depth unit): adding input depth unit storage in the dataset metadata (#3899)
* fix(depth unit): storing raw depth units in the dataset metadata for correct depth statistics and depth raw frames handling. The unit is stored as a string ("m","mm") under "depth_unit" at the same level as "is_depth_map". Unit is inferred from the depth frame type.

* feat(raw frame unit): adapting dataset reader so that raw depth frames are scaled according to the requested unit

* feat(stats units): rescaling stats when loading a dataset so that the stats are given in the requested unit

* tests(unit): adapting and extending depth tests to units manipulations

* chore(format): formatting code

* feat(warning): adding a warning when depth unit is not specified in the dataset

* chore(infer_depth_unit): moving the depth unit inference utility in a more accessible location

* feat(rerun unit): adding correct depth unit display for rerun (foxglove does not support units yet)

* feat(unit getter): adding a proper output_depth_unit getter to LeRobotDataset for cleaner integration

* fix(streaming dataset): extending support for depth units to streaming datasets

* test(rerun): fixing rerun tests
2026-07-02 11:53:13 +02:00
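The stats rescaling introduced in c746ca2df2 can be shown with plain numbers. The sketch below mirrors the logic of `rescale_depth_stats` from the diff further down, but on a single feature with float stats instead of arrays. The function name is hypothetical, and `MM_PER_METRE = 1000.0` is an assumption based on the constant's name, since its value is not shown here.

```python
MM_PER_METRE = 1000.0  # assumed value; the real constant lives in depth_utils


def rescale_depth_stats_sketch(stats, stored_unit, output_unit):
    """Convert unit-bearing stats from stored_unit to output_unit ("m" or "mm")."""
    # No recorded unit, or nothing to convert: return the stats as is.
    if stored_unit is None or stored_unit == output_unit:
        return dict(stats)
    factor = MM_PER_METRE if stored_unit == "m" else 1.0 / MM_PER_METRE
    # "count" is a sample count, not a depth, so it is never rescaled.
    return {k: v if k == "count" else v * factor for k, v in stats.items()}


stats_m = {"mean": 1.5, "max": 4.0, "count": 100}
print(rescale_depth_stats_sketch(stats_m, "m", "mm"))
# prints {'mean': 1500.0, 'max': 4000.0, 'count': 100}
```

A dataset with no recorded `depth_unit` takes the first branch, which matches the warning added in this commit: values are returned as is, with no conversion.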
Caroline Pascal b961d2a8c5 feat(libaom-av1): adding support for libaom-av1 codec (#3898) 2026-07-02 11:03:41 +02:00
Steven Palma 052d329470 feat(visualization): add foxglove support (#3902)
* Add Foxglove display mode for teleoperate

Add a --display_mode flag (rerun|foxglove) to lerobot-teleoperate. When set
to foxglove, stream observations/actions over a Foxglove WebSocket server:
images as RawImage/CompressedImage, scalars as typed JSON channels with
schemas generated from the feature names (sanitized so paths don't need
quoting). Adds a `foxglove` extra.

* Add Foxglove display mode to lerobot-record

Wire the --display_mode flag (rerun|foxglove) into lerobot-record, matching
lerobot-teleoperate: route init/log through the backend-agnostic dispatchers
and stop the visualization backend on exit.

* update foxglove-sdk to 0.25.1

* Use static lerobot.Scalars schema for Foxglove state topics

Replace the per-topic JSON schema derived from feature names with a single
static lerobot.Scalars schema: a scalars array of {label, value} objects. The
same schema fits any robot regardless of which observation/action features it
reports, and the label field lets Foxglove name each series automatically so
one filtered path plots every feature.

* add foxglove option to dataset viz

* Make Foxglove dataset playback loop the sole frame emitter

Address review: the listener no longer emits frames, it only mutates
playback state and queues a one-shot seek index that the playback loop
services. The loop is now the only caller of emit_frame, so concurrent
random access into the on-disk dataset / video decoder never overlaps.

Also remove the dead server_holder and tighten the _foxglove_safe_name
docstring to state what it does and why.

* Label Foxglove dataset scalars with feature dimension names

Use the dataset's per-dimension feature names (e.g. joint names) as the
Foxglove series labels for /observation/state and /action/state instead
of bare indices. LeRobot stores `names` inconsistently (flat list,
{category: [...]}, or {name: index}), so _feature_dim_names handles each
and falls back to indices on any unknown format or length mismatch.

* Make Foxglove server host bindable and refactor topic/channel handling

Pass display_ip through as the Foxglove WebSocket bind host (127.0.0.1
for local only, 0.0.0.0 for all interfaces) instead of always binding
locally. In lerobot-dataset-viz, fold the separate --port into --web-port
so one flag covers both the Rerun web viewer and the Foxglove server port.

Add a _foxglove_topic() helper and thread a per-topic channel cache
through the log helpers so dataset playback stays self-contained instead
of mutating the module-global cache. Promote SUCCESS to constants.py.

* feat(viz): add support for foxglove in rollout + add to viz tag

* fix(docs): remove misleading installation note

* fix(visualization): no duplicated prefix, consolidated norm + warnings log

* chore(viz): minor improvements

* refactor(viz): split files + autoplay + updated docs + added minimal tests

* fix(viz): right tags + warning

* feat(deprecated ws-port): removing rerun's deprecated ws-port parameter in dataset visualization

* chore(web ports): adding global variables for default foxglove/rerun web ports

* feat(depth): adding depth support to foxglove visualizer. Because of foxglove limitations (min and max values on RawImage cannot be set from the SDK), depth is normalized between [0,1] when a depth range is provided.

* fix(rerun depth range): making rerun depth range computation safe against missing stats

* chore(foxglove depth): make it simple, and make it work.

* fix(scaling): fixing depth frames scaling

---------

Co-authored-by: Roman Shtylman <roman@foxglove.dev>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-07-01 18:39:32 +02:00
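Two ideas from the Foxglove commit above lend themselves to a short sketch: the static `lerobot.Scalars` payload (a `scalars` array of `{label, value}` objects) and the normalization of the three `names` layouts. The function names below are hypothetical and the code is not the implementation from the PR. It only follows the behavior described in the commit message, including the fallback to indices on an unknown format or a length mismatch.

```python
def feature_dim_names_sketch(names, dim):
    """Flatten a `names` entry to `dim` labels, falling back to "0".."dim-1"."""
    fallback = [str(i) for i in range(dim)]
    if isinstance(names, list):
        flat = names
    elif isinstance(names, dict) and names and all(isinstance(v, list) for v in names.values()):
        # {category: [...]} layout: concatenate the per-category lists in order.
        flat = [name for group in names.values() for name in group]
    elif isinstance(names, dict) and names and all(isinstance(v, int) for v in names.values()):
        # {name: index} layout: order the names by their index.
        flat = sorted(names, key=names.get)
    else:
        return fallback
    return [str(name) for name in flat] if len(flat) == dim else fallback


def scalars_message(names, values):
    """Build one payload shaped like the static Scalars schema."""
    labels = feature_dim_names_sketch(names, len(values))
    return {"scalars": [{"label": lb, "value": float(v)} for lb, v in zip(labels, values)]}
```

Because every topic shares this shape, one message layout fits any robot, and the `label` field is what lets a viewer name each series without a per-robot schema.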
43 changed files with 1863 additions and 593 deletions
+5 -5
@@ -126,7 +126,7 @@ import time
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
from lerobot.utils.visualization_utils import init_visualization, log_visualization_data, shutdown_visualization
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
@@ -142,7 +142,7 @@ teleop_config = SO101LeaderConfig(
id="my_leader_arm",
)
init_rerun(session_name="teleoperation")
init_visualization("rerun", session_name="teleoperation") # pass "foxglove" to stream to Foxglove instead
robot = SO101Follower(robot_config)
teleop_device = SO101Leader(teleop_config)
@@ -158,7 +158,7 @@ while True:
observation = robot.get_observation()
action = teleop_device.get_action()
robot.send_action(action)
log_rerun_data(observation=observation, action=action)
log_visualization_data("rerun", observation=observation, action=action)
elapsed_time = time.perf_counter() - start_time
sleep_time = TIME_PER_FRAME - elapsed_time
@@ -223,7 +223,7 @@ from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_visualization
from lerobot.scripts.lerobot_record import record_loop
from lerobot.processor import make_default_processors
@@ -270,7 +270,7 @@ def main():
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_rerun(session_name="recording")
init_visualization("rerun", session_name="recording")
# Connect the robot and teleoperator
robot.connect()
+2
@@ -265,6 +265,8 @@ lerobot-dataset-viz \
Once executed, the tool opens `rerun.io` and displays the camera streams, robot states, and actions for the selected episode.
To use [Foxglove](https://foxglove.dev) instead of Rerun, install the extra and add `--display-mode foxglove`. This starts a WebSocket server (connect the Foxglove app to `ws://127.0.0.1:8765`) that serves the episode as a seekable timeline you can play, pause, and scrub.
For advanced usage—including visualizing datasets stored on a remote server—run:
```bash
@@ -134,9 +134,6 @@ lerobot-train \
> [!TIP]
> This is purely a decode-time presentation choice — it does **not** alter the stored video or its metadata, so the same dataset can be read as `mm` or `m` without re-encoding. It has no effect on datasets without depth cameras.
> [!IMPORTANT]
> Depth statistics in `meta/stats.json` are always computed in **millimetres**, regardless of the raw frame dtype.
---
## Persistence in dataset metadata
+1
@@ -125,6 +125,7 @@ hardware = [
]
viz = [
"rerun-sdk>=0.24.0,<0.34.0",
"foxglove-sdk>=0.25.1,<0.26.0",
]
# ── User-facing composite extras (map to CLI scripts) ─────
# lerobot-record, lerobot-replay, lerobot-calibrate, lerobot-teleoperate, etc.
+2 -2
@@ -18,7 +18,6 @@ from __future__ import annotations
# Utilities
########################################################################################
import time
from contextlib import nullcontext
from copy import copy
from typing import TYPE_CHECKING, Any
@@ -26,6 +25,7 @@ import numpy as np
import torch
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _deepdiff_available, require_package
if TYPE_CHECKING or _deepdiff_available:
@@ -76,7 +76,7 @@ def predict_action(
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
get_safe_autocast_context(device, enabled=use_amp),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
observation = prepare_observation_for_inference(observation, device, task, robot_type)
+6
@@ -34,6 +34,8 @@ from .types import (
)
from .video import (
DEFAULT_DEPTH_UNIT,
DEPTH_METER_UNIT,
DEPTH_MILLIMETER_UNIT,
VALID_VIDEO_CODECS,
VIDEO_ENCODER_INFO_KEYS,
DepthEncoderConfig,
@@ -41,6 +43,7 @@ from .video import (
VideoEncoderConfig,
depth_encoder_defaults,
encoder_config_from_video_info,
infer_depth_unit,
rgb_encoder_defaults,
)
@@ -70,8 +73,11 @@ __all__ = [
"depth_encoder_defaults",
# Factories
"encoder_config_from_video_info",
"infer_depth_unit",
# Constants
"DEFAULT_DEPTH_UNIT",
"DEPTH_METER_UNIT",
"DEPTH_MILLIMETER_UNIT",
"VALID_VIDEO_CODECS",
"VIDEO_ENCODER_INFO_KEYS",
]
+24 -5
@@ -22,6 +22,8 @@ import logging
from dataclasses import dataclass, field
from typing import Any, ClassVar, Self
import numpy as np
from lerobot.utils.import_utils import require_package
logger = logging.getLogger(__name__)
@@ -36,7 +38,9 @@ HW_VIDEO_CODECS = [
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
{"h264", "hevc", "libsvtav1", "libaom-av1", "auto", *HW_VIDEO_CODECS}
)
# Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
@@ -65,6 +69,15 @@ DEPTH_METER_UNIT: str = "m"
DEPTH_MILLIMETER_UNIT: str = "mm"
DEFAULT_DEPTH_UNIT: str = DEPTH_MILLIMETER_UNIT
def infer_depth_unit(dtype: np.dtype | type) -> str:
"""Infer the physical unit of raw depth frames from their dtype.
Floating-point frames are assumed to be in metres, integer frames in millimetres.
"""
return DEPTH_METER_UNIT if np.issubdtype(np.dtype(dtype), np.floating) else DEPTH_MILLIMETER_UNIT
# Depth-specific tuning fields persisted under ``features[*]["info"]`` as ``video.<name>``.
DEPTH_ENCODER_INFO_FIELD_NAMES: frozenset[str] = frozenset({"depth_min", "depth_max", "shift", "use_log"})
@@ -213,18 +226,24 @@ class VideoEncoderConfig:
if encoder_threads is not None:
svtav1_parts.append(f"lp={encoder_threads}")
if svtav1_parts:
opts["svtav1-params"] = ":".join(svtav1_parts)
set_if("svtav1-params", ":".join(svtav1_parts))
elif self.vcodec in ("h264", "hevc"):
set_if("crf", self.crf)
set_if("preset", self.preset)
if self.fast_decode:
opts["tune"] = "fastdecode"
set_if("tune", "fastdecode")
set_if("threads", encoder_threads)
elif self.vcodec == "libaom-av1":
set_if("crf", self.crf)
set_if("preset", self.preset)
if encoder_threads is not None:
set_if("threads", encoder_threads)
set_if("row-mt", 1)
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
if self.crf is not None:
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
set_if("q:v", max(1, min(100, 100 - self.crf * 2)))
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
opts["rc"] = 0
set_if("rc", 0)
set_if("qp", self.crf)
set_if("preset", self.preset)
elif self.vcodec == "h264_vaapi":
+5 -11
@@ -22,7 +22,6 @@ import numpy as np
from lerobot.processor import RelativeActionsProcessorStep
from lerobot.utils.constants import ACTION, OBS_STATE
from .depth_utils import MM_PER_METRE
from .io_utils import load_image_as_numpy
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
@@ -509,8 +508,8 @@ def compute_episode_stats(
Note:
For 'image'/'video' features, stats are computed per channel and kept with a
leading channel axis (e.g. shape (3, 1, 1) for RGB). RGB stats are divided by
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) are
instead canonicalized to millimetres regardless of the raw frame unit.
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) skip
this rescaling and remain in their stored units (stored in ``depth_unit``).
"""
if quantile_list is None:
quantile_list = DEFAULT_QUANTILES
@@ -534,14 +533,9 @@ def compute_episode_stats(
)
if features[key]["dtype"] in ["image", "video"]:
if (features[key].get("info") or {}).get("is_depth_map", False):
# Depth stats are canonically stored in millimetres; metre (float) depth is
# scaled up, integer (millimetre) depth is left as-is.
normalization_factor = (
1.0 / MM_PER_METRE if np.issubdtype(ep_ft_array.dtype, np.floating) else 1.0
)
else:
normalization_factor = 255.0
normalization_factor = (
255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
)
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0)
for k, v in ep_stats[key].items()
+31 -1
@@ -26,12 +26,13 @@ import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import snapshot_download
from lerobot.configs import VideoEncoderConfig
from lerobot.configs import DEPTH_METER_UNIT, VideoEncoderConfig
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
from lerobot.utils.feature_utils import _validate_feature_names
from lerobot.utils.utils import flatten_dict
from .compute_stats import aggregate_stats
from .depth_utils import MM_PER_METRE
from .feature_utils import create_empty_dataset_info
from .io_utils import (
get_file_size_in_mb,
@@ -358,6 +359,35 @@ class LeRobotDatasetMetadata:
return [key for key, ft in self.features.items() if _is_depth(ft)]
def rescale_depth_stats(self, output_unit: str) -> None:
"""Rescale depth feature stats in place from their recorded unit to ``output_unit``.
Depth stats are stored in the unit the frames were recorded in
(``features[key]["info"]["depth_unit"]``), while frames are returned in
``output_unit`` on read. This converts the unit-bearing stat entries so
stats match the frames consumers see.
"""
missing_unit_keys = [
key for key in self.depth_keys if (self.features[key].get("info") or {}).get("depth_unit") is None
]
if missing_unit_keys:
logging.warning(
f"Depth feature(s) {missing_unit_keys} have no recorded 'depth_unit' in their info. "
f"Depth maps and stats for these keys will be returned AS IS, with no unit conversion "
f"to the requested output unit {output_unit!r}. Re-record the dataset or set 'depth_unit' "
f"in the feature info (meta/info.json) to enable conversion."
)
if self.stats is None:
return
for key in self.depth_keys:
stored_unit = (self.features[key].get("info") or {}).get("depth_unit")
if stored_unit is None or stored_unit == output_unit or key not in self.stats:
continue
factor = MM_PER_METRE if stored_unit == DEPTH_METER_UNIT else 1.0 / MM_PER_METRE
self.stats[key] = {
stat: value if stat == "count" else value * factor for stat, value in self.stats[key].items()
}
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
+20 -2
@@ -22,10 +22,14 @@ from pathlib import Path
import datasets
import torch
from lerobot.configs import DEFAULT_DEPTH_UNIT, DepthEncoderConfig
from lerobot.configs import (
DEFAULT_DEPTH_UNIT,
DEPTH_METER_UNIT,
DepthEncoderConfig,
)
from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .depth_utils import MM_PER_METRE, dequantize_depth
from .feature_utils import (
check_delta_timestamps,
get_delta_indices,
@@ -102,6 +106,13 @@ class DatasetReader:
for vid_key in self._meta.depth_keys
}
# Get the input unit of each depth feature stored as raw images.
self._image_depth_units: dict[str, str | None] = {
key: (self._meta.features[key].get("info") or {}).get("depth_unit")
for key in self._meta.depth_keys
if key in self._meta.image_keys
}
def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms):
@@ -329,6 +340,13 @@ class DatasetReader:
continue
item[cam] = self._image_transforms(item[cam])
# Convert depth features to the output unit.
for key, stored_unit in self._image_depth_units.items():
if key in item and stored_unit is not None and stored_unit != self._depth_output_unit:
item[key] = (
item[key] * MM_PER_METRE if stored_unit == DEPTH_METER_UNIT else item[key] / MM_PER_METRE
)
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
+10
@@ -36,6 +36,7 @@ from lerobot.configs import (
RGBEncoderConfig,
VideoEncoderConfig,
depth_encoder_defaults,
infer_depth_unit,
rgb_encoder_defaults,
)
@@ -209,6 +210,15 @@ class DatasetWriter:
self.episode_buffer["timestamp"].append(timestamp)
self.episode_buffer["task"].append(frame.pop("task"))
# Record each depth feature's input unit once, inferred from the first frame's dtype.
if frame_index == 0:
for depth_key in self._meta.depth_keys:
if depth_key not in frame:
continue
info = self._meta.features[depth_key].setdefault("info", {})
if info.get("depth_unit") is None:
info["depth_unit"] = infer_depth_unit(np.asarray(frame[depth_key]).dtype)
# Start streaming encoder on first frame of episode
if frame_index == 0 and self._streaming_encoder is not None:
self._streaming_encoder.start_episode(
+2 -5
@@ -34,6 +34,7 @@ from lerobot.configs.video import (
DEPTH_METER_UNIT,
DEPTH_MILLIMETER_UNIT,
DEPTH_QMAX,
infer_depth_unit,
)
from .image_writer import squeeze_single_channel
@@ -57,11 +58,7 @@ def _depth_input_to_float32_and_unit(
input_unit: Literal["auto", DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT],
) -> tuple[NDArray[np.float32], Literal[DEPTH_METER_UNIT, DEPTH_MILLIMETER_UNIT]]:
"""Convert depth to float32 in the chosen unit, and return the resolved unit."""
resolved_unit = (
(DEPTH_METER_UNIT if np.issubdtype(depth.dtype, np.floating) else DEPTH_MILLIMETER_UNIT)
if input_unit == "auto"
else input_unit
)
resolved_unit = infer_depth_unit(depth.dtype) if input_unit == "auto" else input_unit
return depth.astype(np.float32, order="K"), resolved_unit
+6
@@ -224,6 +224,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
self.root = self.meta.root
self.revision = self.meta.revision
self.meta.rescale_depth_stats(self._depth_output_unit)
if episodes is not None and any(
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
@@ -350,6 +351,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Frames per second used during data collection."""
return self.meta.fps
@property
def depth_output_unit(self) -> str:
"""Physical unit (``"m"`` or ``"mm"``) depth maps and statistics are returned in on read."""
return self._depth_output_unit
@property
def num_frames(self) -> int:
"""Number of frames in selected episodes."""
+24 -2
@@ -22,11 +22,11 @@ import numpy as np
import torch
from datasets import load_dataset
from lerobot.configs import DEFAULT_DEPTH_UNIT, DepthEncoderConfig
from lerobot.configs import DEFAULT_DEPTH_UNIT, DEPTH_METER_UNIT, DepthEncoderConfig
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .depth_utils import MM_PER_METRE, dequantize_depth
from .feature_utils import get_delta_indices
from .io_utils import item_to_torch
from .utils import (
@@ -310,6 +310,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
)
self.root = self.meta.root
self.revision = self.meta.revision
self.meta.rescale_depth_stats(self._depth_output_unit)
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
@@ -318,6 +319,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
for vid_key in self.meta.depth_keys
}
# Input unit of each depth feature stored as raw images (dequantized separately from videos).
self._image_depth_units: dict[str, str | None] = {
key: (self.meta.features[key].get("info") or {}).get("depth_unit")
for key in self.meta.depth_keys
if key in self.meta.image_keys
}
self.delta_timestamps = None
self.delta_indices = None
@@ -348,6 +356,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
def fps(self):
return self.meta.fps
@property
def depth_output_unit(self) -> str:
"""Physical unit (``"m"`` or ``"mm"``) depth maps are returned in on read."""
return self._depth_output_unit
@staticmethod
def _iter_random_indices(
rng: np.random.Generator, buffer_size: int, random_batch_size=100
@@ -530,6 +543,15 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
for update in updates:
result.update(update)
# Convert raw-image depth features to the output unit (video depth is already converted).
for key, stored_unit in self._image_depth_units.items():
if key in result and stored_unit is not None and stored_unit != self._depth_output_unit:
result[key] = (
result[key] * MM_PER_METRE
if stored_unit == DEPTH_METER_UNIT
else result[key] / MM_PER_METRE
)
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
yield result
+1 -4
@@ -47,7 +47,7 @@ from lerobot.configs import (
)
from lerobot.utils.import_utils import get_safe_default_video_backend
from .depth_utils import MM_PER_METRE, quantize_depth
from .depth_utils import quantize_depth
from .pyav_utils import get_pix_fmt_channels
logger = logging.getLogger(__name__)
@@ -848,9 +848,6 @@ class _CameraEncoderThread(threading.Thread):
# Reshape CHW to (H*W, C) for per-channel stats
channels = img_downsampled.shape[0]
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
# Depth stats are canonically stored in millimetres; metre (float) depth is scaled up.
if self.is_depth and np.issubdtype(frame_data.dtype, np.floating):
img_for_stats = img_for_stats * MM_PER_METRE
stats_tracker.update(img_for_stats)
frame_count += 1
+3 -2
@@ -43,6 +43,7 @@ from torch import Tensor
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
@@ -243,7 +244,7 @@ class GrootPolicy(PreTrainedPolicy):
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
with get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.forward(groot_inputs)
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
@@ -275,7 +276,7 @@ class GrootPolicy(PreTrainedPolicy):
device = next(self.parameters()).device
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
with get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.get_action(groot_inputs)
actions = outputs.get("action_pred")
@@ -31,7 +31,6 @@ import logging
import os
import types
from collections import deque
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any
import numpy as np
@@ -43,6 +42,7 @@ from torch.distributions import Beta
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
from ..rtc.modeling_rtc import RTCProcessor
@@ -1644,10 +1644,8 @@ class MolmoAct2Policy(PreTrainedPolicy):
device=device,
)
action_dim = self._output_action_dim(batch)
autocast_context = (
torch.autocast(device_type=device.type, dtype=model_dtype)
if device.type in {"cuda", "cpu"} and model_dtype in {torch.bfloat16, torch.float16}
else nullcontext()
autocast_context = get_safe_autocast_context(
device, dtype=model_dtype, enabled=model_dtype in {torch.bfloat16, torch.float16}
)
with autocast_context:
if inference_action_mode == "discrete":
@@ -26,6 +26,7 @@ from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.device_utils import get_safe_autocast_context
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
@@ -183,7 +184,7 @@ class VLAJEPAModel(nn.Module):
action_idx = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
with get_safe_autocast_context(device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_idx[0], embodied_idx[1], :].view(b, -1, h)
@@ -250,7 +251,7 @@ class VLAJEPAModel(nn.Module):
) -> Tensor:
"""Flow-matching action-head loss, repeated over `repeated_diffusion_steps`."""
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.float32):
with get_safe_autocast_context(device_type, dtype=torch.float32):
r = self.config.repeated_diffusion_steps
horizon = self.config.chunk_size
actions_target = actions[:, -horizon:, :].to(torch.float32).repeat(r, 1, 1)
+6 -3
@@ -226,11 +226,14 @@ class RolloutConfig:
device: str | None = None
task: str = ""
display_data: bool = False
# Display data on a remote Rerun server
# Visualization backend used when display_data is True: "rerun" or "foxglove".
display_mode: str = "rerun"
# For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket
# server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces).
display_ip: str | None = None
# Port of the remote Rerun server
# For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to.
display_port: int | None = None
# Whether to display compressed images in Rerun
# Whether to display compressed (JPEG) images instead of raw frames
display_compressed_images: bool = False
# Use vocal synthesis to read events
play_sounds: bool = True
+2 -6
@@ -17,7 +17,6 @@
from __future__ import annotations
import logging
from contextlib import nullcontext
from copy import copy
import torch
@@ -25,6 +24,7 @@ import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline
from lerobot.utils.device_utils import get_safe_autocast_context
from .base import InferenceEngine
@@ -102,11 +102,7 @@ class SyncInferenceEngine(InferenceEngine):
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
# tensor/array values are not shared with any other reader.
observation = copy(obs_frame)
autocast_ctx = (
torch.autocast(device_type=self._device.type)
if self._device.type == "cuda" and self._policy.config.use_amp
else nullcontext()
)
autocast_ctx = get_safe_autocast_context(self._device, enabled=self._policy.config.use_amp)
with torch.inference_mode(), autocast_ctx:
observation = prepare_observation_for_inference(
observation, self._device, self._task, self._robot_type
+4 -3
@@ -26,7 +26,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import log_rerun_data
from lerobot.utils.visualization_utils import log_visualization_data
from ..inference import InferenceEngine
@@ -162,11 +162,12 @@ class RolloutStrategy(abc.ABC):
action_dict: dict | None,
runtime_ctx: RuntimeContext,
) -> None:
"""Log observation/action telemetry to Rerun if display_data is enabled."""
"""Log observation/action telemetry to the visualization backend if display_data is enabled."""
cfg = runtime_ctx.cfg
if not cfg.display_data:
return
log_rerun_data(
log_visualization_data(
cfg.display_mode,
observation=obs_processed,
action=action_dict,
compress_images=cfg.display_compressed_images,
+5 -2
@@ -44,7 +44,7 @@ from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import log_rerun_data
from lerobot.utils.visualization_utils import log_visualization_data
from ..configs import EpisodicStrategyConfig
from ..context import RolloutContext
@@ -171,6 +171,7 @@ class EpisodicStrategy(RolloutStrategy):
fps=fps,
control_time_s=reset_time_s,
display_data=cfg.display_data,
display_mode=cfg.display_mode,
display_compressed=display_compressed,
)
@@ -259,6 +260,7 @@ class EpisodicStrategy(RolloutStrategy):
fps: float,
control_time_s: float,
display_data: bool,
display_mode: str,
display_compressed: bool,
) -> None:
"""Reset-phase loop: teleop drives the robot if available, no recording."""
@@ -288,7 +290,8 @@ class EpisodicStrategy(RolloutStrategy):
if display_data:
obs_processed = processors.robot_observation_processor(obs)
log_rerun_data(
log_visualization_data(
display_mode,
observation=obs_processed,
action=act_teleop,
compress_images=display_compressed,
+105 -32
@@ -59,6 +59,18 @@ distant$ lerobot-dataset-viz \
local$ rerun rerun+http://IP:GRPC_PORT/proxy
```
- Visualize data in Foxglove with a seekable, scrubbable timeline:
```
local$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0 \
--display-mode foxglove
# then open the Foxglove app and connect to ws://127.0.0.1:8765
```
This starts a Foxglove WebSocket server that serves the episode on demand from the on-disk dataset,
so you can play/pause and scrub anywhere in the episode using Foxglove's playback controls.
"""
import argparse
@@ -72,10 +84,14 @@ import torch
import torch.utils.data
import tqdm
from lerobot.configs import DEPTH_MILLIMETER_UNIT
from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD, SUCCESS
from lerobot.utils.utils import init_logging
DEFAULT_FOXGLOVE_PORT = 8765
DEFAULT_RERUN_PORT = 9090
def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]:
"""Return per-dimension names for a feature from the dataset metadata.
@@ -108,6 +124,12 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
return hwc_uint8_numpy
def to_hwc_float32_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
check_chw_float32(chw_float32_torch)
hwc_float32_numpy = chw_float32_torch.permute(1, 2, 0).numpy()
return hwc_float32_numpy
def build_blueprint_from_dataset(dataset: LeRobotDataset):
"""Build a Rerun blueprint laying out camera images and time series for the given dataset.
@@ -126,32 +148,43 @@ def build_blueprint_from_dataset(dataset: LeRobotDataset):
names = get_feature_names(dataset, key)
styling = rr.SeriesLines(names=names)
views.append(rrb.TimeSeriesView(origin=origin, name=origin, overrides={origin: styling}))
for key in (DONE, REWARD, "next.success"):
for key in (DONE, REWARD, SUCCESS):
if key in dataset.features:
views.append(rrb.TimeSeriesView(origin=key, name=key))
return rrb.Blueprint(rrb.Grid(*views))
def to_hwc_uint16_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
check_chw_float32(chw_float32_torch)
hwc_uint16_numpy = chw_float32_torch.round().type(torch.uint16).permute(1, 2, 0).numpy()
return hwc_uint16_numpy
def visualize_dataset(
dataset: LeRobotDataset,
episode_index: int,
batch_size: int = 32,
num_workers: int = 0,
mode: str = "local",
web_port: int = 9090,
web_port: int | None = None,
grpc_port: int = 9876,
save: bool = False,
output_dir: Path | None = None,
display_compressed_images: bool = False,
display_mode: str = "rerun",
host: str = "127.0.0.1",
autoplay: bool = True,
**kwargs,
) -> Path | None:
if display_mode == "foxglove":
from lerobot.utils.foxglove_visualization import serve_foxglove_dataset_playback
logging.info("Starting Foxglove server")
serve_foxglove_dataset_playback(
dataset,
episode_index,
host=host,
port=web_port if web_port is not None else DEFAULT_FOXGLOVE_PORT,
compress_images=display_compressed_images,
autoplay=autoplay,
)
return None
if save:
assert output_dir is not None, (
"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
@@ -188,14 +221,23 @@ def visualize_dataset(
if mode == "distant":
server_uri = rr.serve_grpc(grpc_port=grpc_port)
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
rr.serve_web_viewer(
open_browser=False,
web_port=web_port if web_port is not None else DEFAULT_RERUN_PORT,
connect_to=server_uri,
)
logging.info("Logging to Rerun")
# Depth frames and stats are dequantized to the dataset's depth_output_unit on load.
depth_meter = 1000.0 if dataset.depth_output_unit == DEPTH_MILLIMETER_UNIT else 1.0
# Use the dataset's q01/q99 depth statistics for robust depth range bounds
depth_ranges = {}
for key in dataset.meta.depth_keys:
stats = dataset.meta.stats[key]
stats = (dataset.meta.stats or {}).get(key)
if not stats:
continue
lo = stats["q01"] if "q01" in stats else stats["min"]
hi = stats["q99"] if "q99" in stats else stats["max"]
depth_ranges[key] = (float(np.asarray(lo).item()), float(np.asarray(hi).item()))
@@ -213,11 +255,12 @@ def visualize_dataset(
# display each camera image (or depth map)
for key in dataset.meta.camera_keys:
if key in dataset.meta.depth_keys:
depth = to_hwc_uint16_numpy(batch[key][i])
depth = to_hwc_float32_numpy(batch[key][i])
depth_entity = rr.DepthImage(
depth,
meter=depth_meter,
colormap=rr.components.Colormap.Viridis,
depth_range=depth_ranges[key],
depth_range=depth_ranges.get(key),
)
rr.log(key, entity=depth_entity)
else:
@@ -239,8 +282,8 @@ def visualize_dataset(
if REWARD in batch:
rr.log(REWARD, rr.Scalars(batch[REWARD][i].item()))
if "next.success" in batch:
rr.log("next.success", rr.Scalars(batch["next.success"][i].item()))
if SUCCESS in batch:
rr.log(SUCCESS, rr.Scalars(batch[SUCCESS][i].item()))
# save .rrd locally
if mode == "local" and save:
@@ -312,13 +355,11 @@ def main():
parser.add_argument(
"--web-port",
type=int,
default=9090,
help="Web port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--ws-port",
type=int,
help="deprecated, please use --grpc-port instead.",
default=None,
help=(
"Web/WebSocket port. For rerun `--mode distant` it is the web viewer port (default 9090); "
"for `--display-mode foxglove` it is the server bind port (default 8765)."
),
)
parser.add_argument(
"--grpc-port",
@@ -351,24 +392,56 @@ def main():
parser.add_argument(
"--display-compressed-images",
action="store_true",
help="If set, display compressed images in Rerun instead of uncompressed ones.",
help="If set, display compressed (JPEG) images instead of uncompressed ones.",
)
parser.add_argument(
"--display-mode",
type=str,
default="rerun",
choices=["rerun", "foxglove"],
help=(
"Visualization backend. 'rerun' uses the Rerun viewer (--mode/--save/--*-port apply). "
"'foxglove' starts a Foxglove WebSocket server that serves the episode as a seekable, "
"scrubbable timeline; connect the Foxglove app to ws://HOST:PORT (--host/--web-port)."
),
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help=(
"Host to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set "
"(127.0.0.1 for local only, 0.0.0.0 for all interfaces)."
),
)
parser.add_argument(
"--no-autoplay",
dest="autoplay",
action="store_false",
help=(
"For `--display-mode foxglove`: don't start playing automatically when a client "
"connects; wait for play to be pressed in the Foxglove app instead."
),
)
args = parser.parse_args()
if args.display_mode == "foxglove":
rerun_only = ("mode", "save", "output_dir", "grpc_port", "batch_size", "num_workers")
ignored = [name for name in rerun_only if getattr(args, name) != parser.get_default(name)]
if ignored:
logging.warning(
"These flags only apply to `--display-mode rerun` and are ignored with "
"`--display-mode foxglove`: %s.",
", ".join(f"--{name.replace('_', '-')}" for name in ignored),
)
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
if kwargs["ws_port"] is not None:
logging.warning(
"--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
)
logging.warning("Setting grpc_port to ws_port value.")
kwargs["grpc_port"] = kwargs.pop("ws_port")
else:
kwargs.pop("ws_port") # Always remove ws_port from kwargs
init_logging()
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
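Note on the ignored-flag warning in the foxglove branch of `main()` above: it compares each parsed value with `parser.get_default(name)` to find Rerun-only flags the user changed. The following is a minimal, self-contained sketch of that pattern, assuming a trimmed-down hypothetical parser (`build_parser` and `ignored_flags` are illustrative names, not part of the script):

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical, reduced parser: only the flags needed for the illustration.
    parser = argparse.ArgumentParser()
    parser.add_argument("--display-mode", default="rerun", choices=["rerun", "foxglove"])
    parser.add_argument("--mode", default="local")
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--save", action="store_true")
    return parser


def ignored_flags(
    parser: argparse.ArgumentParser,
    args: argparse.Namespace,
    rerun_only: tuple[str, ...] = ("mode", "save", "batch_size"),
) -> list[str]:
    """Return the Rerun-only flags whose value differs from the parser default."""
    if args.display_mode != "foxglove":
        return []
    changed = [name for name in rerun_only if getattr(args, name) != parser.get_default(name)]
    # Convert argparse dest names back to their command-line spelling.
    return [f"--{name.replace('_', '-')}" for name in changed]


parser = build_parser()
args = parser.parse_args(["--display-mode", "foxglove", "--batch-size", "8", "--save"])
print(ignored_flags(parser, args))
```

One limitation of comparing against defaults: a flag passed explicitly with its default value (for example `--batch-size 32`) is indistinguishable from an omitted flag, so it is not reported.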
+2 -3
@@ -56,7 +56,6 @@ import threading
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from functools import partial
@@ -86,7 +85,7 @@ from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_proces
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.device_utils import get_safe_autocast_context, get_safe_torch_device
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
@@ -698,7 +697,7 @@ def eval_main(cfg: EvalPipelineConfig):
max_episodes_rendered = 0 if cfg.eval.recording else 10
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
with torch.no_grad(), get_safe_autocast_context(device, enabled=cfg.policy.use_amp):
info = eval_policy_all(
envs=envs,
policy=policy,
+28 -7
@@ -38,6 +38,9 @@ lerobot-record \\
--display_data=true
```
To stream the data to Foxglove instead of Rerun, add ``--display_mode=foxglove`` (then connect the
Foxglove app to ``ws://127.0.0.1:8765``; override the port with ``--display_port=<port>``).
Example recording with bimanual so100:
```shell
lerobot-record \\
@@ -157,7 +160,11 @@ from lerobot.utils.utils import (
init_logging,
log_say,
)
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
from lerobot.utils.visualization_utils import (
init_visualization,
log_visualization_data,
shutdown_visualization,
)
@dataclass
@@ -168,11 +175,14 @@ class RecordConfig:
teleop: TeleoperatorConfig | None = None
# Display all cameras on screen
display_data: bool = False
# Display data on a remote Rerun server
# Visualization backend used when display_data is True: "rerun" or "foxglove".
display_mode: str = "rerun"
# For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket
# server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces).
display_ip: str | None = None
# Port of the remote Rerun server
# For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to.
display_port: int | None = None
# Whether to display compressed images in Rerun
# Whether to display compressed (JPEG) images instead of raw frames
display_compressed_images: bool = False
# Use vocal synthesis to read events.
play_sounds: bool = True
@@ -233,6 +243,7 @@ def record_loop(
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
display_mode: str = "rerun",
display_compressed_images: bool = False,
):
if dataset is not None and dataset.fps != fps:
@@ -327,8 +338,11 @@ def record_loop(
dataset.add_frame(frame)
if display_data:
log_rerun_data(
observation=obs_processed, action=action_values, compress_images=display_compressed_images
log_visualization_data(
display_mode,
observation=obs_processed,
action=action_values,
compress_images=display_compressed_images,
)
dt_s = time.perf_counter() - start_loop_t
@@ -354,7 +368,9 @@ def record(
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
init_rerun(session_name="recording", ip=cfg.display_ip, port=cfg.display_port)
init_visualization(
cfg.display_mode, session_name="recording", ip=cfg.display_ip, port=cfg.display_port
)
display_compressed_images = (
True
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
@@ -464,6 +480,7 @@ def record(
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
display_mode=cfg.display_mode,
display_compressed_images=display_compressed_images,
)
@@ -485,6 +502,7 @@ def record(
control_time_s=cfg.dataset.reset_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
display_mode=cfg.display_mode,
)
if events["rerecord_episode"]:
@@ -510,6 +528,9 @@ def record(
if listener is not None:
listener.stop()
if cfg.display_data:
shutdown_visualization(cfg.display_mode)
if cfg.dataset.push_to_hub:
if dataset and dataset.num_episodes > 0:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
+13 -3
@@ -145,6 +145,9 @@ Usage examples
--dataset.rgb_encoder.vcodec=h264 \\
--dataset.rgb_encoder.preset=fast \\
--dataset.rgb_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2}
# Stream to Foxglove instead of Rerun:
# add --display_mode=foxglove, then connect the Foxglove app to ws://127.0.0.1:8765.
"""
import logging
@@ -190,7 +193,7 @@ from lerobot.teleoperators import ( # noqa: F401
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
from lerobot.utils.visualization_utils import init_rerun
from lerobot.utils.visualization_utils import init_visualization, shutdown_visualization
logger = logging.getLogger(__name__)
@@ -201,8 +204,13 @@ def rollout(cfg: RolloutConfig):
init_logging()
if cfg.display_data:
logger.info("Initializing Rerun visualization (ip=%s, port=%s)", cfg.display_ip, cfg.display_port)
init_rerun(session_name="rollout", ip=cfg.display_ip, port=cfg.display_port)
logger.info(
"Initializing %s visualization (ip=%s, port=%s)",
cfg.display_mode,
cfg.display_ip,
cfg.display_port,
)
init_visualization(cfg.display_mode, session_name="rollout", ip=cfg.display_ip, port=cfg.display_port)
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
@@ -227,6 +235,8 @@ def rollout(cfg: RolloutConfig):
logger.info("Interrupted by user")
finally:
strategy.teardown(ctx)
if cfg.display_data:
shutdown_visualization(cfg.display_mode)
logger.info("Rollout finished")
+39 -9
@@ -31,6 +31,22 @@ lerobot-teleoperate \
--display_data=true
```
To stream the data to Foxglove instead of Rerun, add ``--display_mode=foxglove``
(then connect the Foxglove app to ``ws://127.0.0.1:8765``; override the port with ``--display_port=<port>``):
```shell
lerobot-teleoperate \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--robot.id=black \
--teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue \
--display_data=true \
--display_mode=foxglove
```
Example teleoperation with bimanual so100:
```shell
@@ -108,7 +124,11 @@ from lerobot.teleoperators import ( # noqa: F401
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, move_cursor_up
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
from lerobot.utils.visualization_utils import (
init_visualization,
log_visualization_data,
shutdown_visualization,
)
@dataclass
@@ -121,11 +141,14 @@ class TeleoperateConfig:
teleop_time_s: float | None = None
# Display all cameras on screen
display_data: bool = False
# Display data on a remote Rerun server
# Visualization backend used when display_data is True: "rerun" or "foxglove".
display_mode: str = "rerun"
# For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket
# server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces).
display_ip: str | None = None
# Port of the remote Rerun server
# For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to.
display_port: int | None = None
# Whether to display compressed images in Rerun
# Whether to display compressed (JPEG) images instead of raw frames
display_compressed_images: bool = False
@@ -137,6 +160,7 @@ def teleop_loop(
robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction],
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation],
display_data: bool = False,
display_mode: str = "rerun",
duration: float | None = None,
display_compressed_images: bool = False,
):
@@ -149,8 +173,10 @@ def teleop_loop(
teleop: The teleoperator device instance providing control actions.
robot: The robot instance being controlled.
fps: The target frequency for the control loop in frames per second.
display_data: If True, fetches robot observations and displays them in the console and Rerun.
display_compressed_images: If True, compresses images before sending them to Rerun for display.
display_data: If True, fetches robot observations and displays them in the console and the
visualization backend.
display_mode: Visualization backend to use when display_data is True ("rerun" or "foxglove").
display_compressed_images: If True, compresses images before sending them to the backend for display.
duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely.
teleop_action_processor: An optional pipeline to process raw actions from the teleoperator.
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
@@ -187,7 +213,8 @@ def teleop_loop(
# Process robot observation through pipeline
obs_transition = robot_observation_processor(obs)
log_rerun_data(
log_visualization_data(
display_mode,
observation=obs_transition,
action=teleop_action,
compress_images=display_compressed_images,
@@ -215,7 +242,9 @@ def teleoperate(cfg: TeleoperateConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
init_rerun(session_name="teleoperation", ip=cfg.display_ip, port=cfg.display_port)
init_visualization(
cfg.display_mode, session_name="teleoperation", ip=cfg.display_ip, port=cfg.display_port
)
display_compressed_images = (
True
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
@@ -235,6 +264,7 @@ def teleoperate(cfg: TeleoperateConfig):
robot=robot,
fps=cfg.fps,
display_data=cfg.display_data,
display_mode=cfg.display_mode,
duration=cfg.teleop_time_s,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
@@ -245,7 +275,7 @@ def teleoperate(cfg: TeleoperateConfig):
pass
finally:
if cfg.display_data:
shutdown_rerun()
shutdown_visualization(cfg.display_mode)
teleop.disconnect()
robot.disconnect()
+4
@@ -211,8 +211,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
force_cpu = cfg.trainable_config.device == "cpu"
# Drive Accelerate's autocast from policy.dtype (bfloat16/float16 enable it; float32 -> "no"; absent dtype -> launcher default).
policy_dtype = getattr(cfg.trainable_config, "dtype", None)
mixed_precision = {"bfloat16": "bf16", "float16": "fp16", "float32": "no"}.get(policy_dtype)
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
mixed_precision=mixed_precision,
kwargs_handlers=[ddp_kwargs],
cpu=force_cpu,
)
+7 -1
@@ -33,7 +33,12 @@ from .constants import (
REWARD,
)
from .decorators import check_if_already_connected, check_if_not_connected
from .device_utils import auto_select_torch_device, get_safe_torch_device, is_torch_device_available
from .device_utils import (
auto_select_torch_device,
get_safe_autocast_context,
get_safe_torch_device,
is_torch_device_available,
)
from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from .import_utils import is_package_available, require_package
@@ -51,6 +56,7 @@ __all__ = [
"REWARD",
# Device utilities
"auto_select_torch_device",
"get_safe_autocast_context",
"get_safe_torch_device",
"is_torch_device_available",
# Import guards
+1
@@ -37,6 +37,7 @@ ACTION_TOKEN_MASK = ACTION + ".token_mask"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"
SUCCESS = "next.success"
INFO = "info"
ROBOTS = "robots"
+23
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from contextlib import AbstractContextManager, nullcontext
import torch
@@ -107,3 +108,25 @@ def is_amp_available(device: str):
return False
else:
raise ValueError(f"Unknown device '{device}'.")
def get_safe_autocast_context(
device: str | torch.device,
*,
dtype: torch.dtype | None = None,
enabled: bool = True,
) -> AbstractContextManager:
"""Return a ``torch.autocast`` context, or a no-op when AMP is unsupported on ``device``.
Autocast is only entered on devices that support AMP (cuda, xpu, cpu); on mps and any
unknown device this falls back to ``nullcontext()`` so callers can request autocast
unconditionally without breaking on unsupported backends.
"""
device_type = device.type if isinstance(device, torch.device) else str(device).split(":", 1)[0]
try:
amp_ok = is_amp_available(device_type)
except ValueError:
amp_ok = False
if not enabled or not amp_ok:
return nullcontext()
return torch.autocast(device_type=device_type, dtype=dtype)
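Note on the gating logic above: the decision to enter autocast depends only on the device type and the `enabled` flag. The following torch-free sketch isolates that decision. It is a simplified stand-in, not the helper itself: `AMP_DEVICE_TYPES`, `device_type_of`, and `should_autocast` are hypothetical names, the supported set is copied from the docstring rather than queried through `is_amp_available`, and duck typing on `.type` replaces the `isinstance(device, torch.device)` check.

```python
AMP_DEVICE_TYPES = frozenset({"cuda", "xpu", "cpu"})  # assumption: set named in the docstring


def device_type_of(device: object) -> str:
    """Return the bare device type, e.g. "cuda" for "cuda:0"."""
    kind = getattr(device, "type", None)  # objects shaped like torch.device expose .type
    if isinstance(kind, str):
        return kind
    # Plain strings may carry an index suffix such as ":0"; keep only the type.
    return str(device).split(":", 1)[0]


def should_autocast(device: object, *, enabled: bool = True) -> bool:
    """True when autocast would be entered, False when a no-op context is used."""
    return enabled and device_type_of(device) in AMP_DEVICE_TYPES


class FakeDevice:
    """Minimal stand-in for torch.device, used only in this sketch."""

    def __init__(self, type_: str) -> None:
        self.type = type_


for dev in ("cuda:0", "cpu", "mps", "tpu", FakeDevice("xpu")):
    print(device_type_of(dev), should_autocast(dev))
```

Unknown device types fall through to `False` here, which corresponds to the helper catching `ValueError` from `is_amp_available` and returning `nullcontext()`.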
+651
@@ -0,0 +1,651 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Foxglove visualization backend.
Live control-loop streaming (:func:`log_foxglove_data`) and seekable dataset playback
(:func:`serve_foxglove_dataset_playback`) over a Foxglove WebSocket server. Callers usually select a
backend at runtime through the dispatch in :mod:`lerobot.utils.visualization_utils` rather than
importing from here directly. Requires the ``viz`` extra (``pip install 'lerobot[viz]'``).
"""
import logging
import numbers
import time
import cv2
import numpy as np
from lerobot.types import RobotAction, RobotObservation
from .constants import (
ACTION,
ACTION_PREFIX,
DONE,
OBS_IMAGES,
OBS_PREFIX,
OBS_STATE,
OBS_STR,
REWARD,
SUCCESS,
TRUNCATED,
)
from .import_utils import require_package
# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}``
# pairs rather than one field per feature, so the same schema fits any robot regardless of which
# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name
# each series automatically, so a single filtered path plots every feature, e.g.
# ``/observation/state.scalars[:]``.
_SCALARS_SCHEMA = {
"type": "object",
"title": "lerobot.Scalars",
"properties": {
"scalars": {
"type": "array",
"items": {
"type": "object",
"properties": {
"label": {"type": "string"},
"value": {"type": "number"},
},
},
}
},
}
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
def init_foxglove(host: str = "127.0.0.1", port: int | None = 8765) -> None:
"""
Starts a Foxglove WebSocket server for visualizing the control loop.
Connect to it from the Foxglove app at ``ws://<host>:<port>``. Calling this
more than once is a no-op while a server is already running.
Args:
host: Host interface to bind the WebSocket server to.
port: Port to bind the WebSocket server to (defaults to 8765).
"""
require_package("foxglove-sdk", extra="viz", import_name="foxglove")
import foxglove
# Live-stream state lives as attributes on ``log_foxglove_data``:
# ``.server`` is the shared WebSocket server and
# ``.channels`` caches one Foxglove channel per topic
if getattr(log_foxglove_data, "server", None) is not None:
return
log_foxglove_data.server = foxglove.start_server(host=host, port=port or 8765)
log_foxglove_data.channels = {}
def shutdown_foxglove() -> None:
"""Stops the Foxglove WebSocket server and clears cached channels."""
server = getattr(log_foxglove_data, "server", None)
if server is not None:
server.stop()
log_foxglove_data.server = None
log_foxglove_data.channels = {}
def _foxglove_safe_name(name: str) -> str:
"""Replace ``.`` with ``_`` so a feature name is a single Foxglove topic-path segment.
Foxglove treats ``.`` as a path separator, so an unsanitized name like ``observation.images.front``
would split into nested segments instead of naming one topic.
"""
return name.replace(".", "_")
def _foxglove_topic(key: str, *, is_image: bool = False) -> str:
"""Build the Foxglove topic for a feature ``key``.
Camera features map to a per-source image topic (``/observation/images/<name>``); scalar features
share one aggregate topic per source: ``/observation/state`` for observations, ``/action/state``
for actions.
"""
if is_image:
name = str(key)
for prefix in (f"{OBS_IMAGES}.", OBS_PREFIX):
if name.startswith(prefix):
name = name[len(prefix) :]
break
return f"/{OBS_STR}/images/{_foxglove_safe_name(name)}"
source = ACTION if (str(key).startswith(ACTION_PREFIX) or str(key) == ACTION) else OBS_STR
return f"/{source}/state"
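Note on the topic mapping above: a few concrete inputs make the resulting topic names easier to see. The sketch below restates the two helpers as standalone functions. The constant values (`"observation"`, `"observation."`, `"observation.images"`, `"action"`, `"action."`) are assumptions about `lerobot.utils.constants`, not confirmed by this diff.

```python
# Assumed constant values; the real ones live in lerobot.utils.constants.
OBS_STR = "observation"
OBS_PREFIX = "observation."
OBS_IMAGES = "observation.images"
ACTION = "action"
ACTION_PREFIX = "action."


def foxglove_safe_name(name: str) -> str:
    # "." acts as a path separator in Foxglove, so flatten it to "_".
    return name.replace(".", "_")


def foxglove_topic(key: str, *, is_image: bool = False) -> str:
    if is_image:
        name = str(key)
        # Strip the most specific known prefix first, then stop.
        for prefix in (f"{OBS_IMAGES}.", OBS_PREFIX):
            if name.startswith(prefix):
                name = name[len(prefix):]
                break
        return f"/{OBS_STR}/images/{foxglove_safe_name(name)}"
    is_action = str(key).startswith(ACTION_PREFIX) or str(key) == ACTION
    return f"/{ACTION if is_action else OBS_STR}/state"


for key, is_image in [
    ("observation.images.front", True),
    ("observation.images.wrist.left", True),
    ("observation.state", False),
    ("action.gripper", False),
]:
    print(key, "->", foxglove_topic(key, is_image=is_image))
```

Under these assumptions every scalar feature collapses onto one of two aggregate topics, while each camera keeps its own image topic.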
def _log_foxglove_scalars(
topic: str, values: dict[str, float], *, channels: dict | None = None, log_time: int | None = None
) -> None:
"""Log scalars on a typed JSON channel using the static :data:`_SCALARS_SCHEMA`.
``values`` is an ordered mapping of feature name to value; it is emitted as a ``scalars`` array of
``{label, value}`` objects. Insertion order is preserved so series stay stable across messages.
``channels`` is the per-topic channel cache to reuse (defaults to the live-stream cache on
:func:`log_foxglove_data`; dataset playback passes its own local cache to stay self-contained).
``log_time`` is the message time in nanoseconds; when ``None`` the server's receive time is used.
"""
if not values:
return
import foxglove
if channels is None:
channels = log_foxglove_data.channels
channel = channels.get(topic)
if channel is None:
channel = channels[topic] = foxglove.Channel(topic, schema=_SCALARS_SCHEMA, message_encoding="json")
msg = {"scalars": [{"label": label, "value": value} for label, value in values.items()]}
if log_time is None:
channel.log(msg)
else:
channel.log(msg, log_time=log_time)
def _labeled_scalars(name: str, values, labels: list[str] | None = None) -> dict[str, float]:
"""Expand a 1D sequence into ``{label: value}`` entries with a consistent fallback."""
flat = [float(v) for v in values]
if labels is None or len(labels) != len(flat):
labels = [f"{name}_{i}" for i in range(len(flat))]
return dict(zip(labels, flat, strict=True))
def _log_foxglove_image(
topic: str,
frame_id: str,
arr: np.ndarray,
*,
compress_images: bool,
channels: dict | None = None,
log_time: int | None = None,
depth_range: tuple[float, float] | None = None,
raw_depth_values: bool = False,
) -> None:
"""Log an image on a cached per-topic channel.
The encoding is chosen from the channel count and dtype: a single-channel ``float`` or ``uint16``
frame is a depth map (``32FC1``/``16UC1``), single-channel ``uint8`` is ``mono8``, 3 => ``rgb8``
(float input assumed in [0, 1], cast to uint8), 4 => ``rgba8``; other counts are skipped with a
warning. When ``compress_images`` is set, ``rgb8`` is JPEG-encoded instead.
Args:
topic: Foxglove topic to log on.
frame_id: Frame id stamped on the message.
arr: Image as HWC or CHW (CHW is transposed to HWC), any dtype.
compress_images: JPEG-encode ``rgb8`` frames; ignored for other encodings.
channels: Per-topic channel cache to reuse (see :func:`_log_foxglove_scalars`).
log_time: Message time in nanoseconds, also written to the header timestamp; when ``None``
the server's receive time is used.
depth_range: ``(lo, hi)`` clip bounds in a depth frame's own input units. Depth frames
(``32FC1``/``16UC1``) are rescaled onto Foxglove's default display max for their encoding
(``1.0`` / ``10000``) so they show with sensible contrast; ``depth_range`` sets the source
range, else the frame's own min/max is used. Ignored for ``mono8``/``rgb8``/``rgba8``.
raw_depth_values: If True, depth values are not rescaled and are logged as is.
"""
from foxglove.channels import CompressedImageChannel, RawImageChannel
from foxglove.messages import CompressedImage, RawImage, Timestamp
if channels is None:
channels = log_foxglove_data.channels
time_ns = time.time_ns() if log_time is None else log_time
timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000)
log_kwargs = {} if log_time is None else {"log_time": log_time}
# Convert CHW -> HWC when needed (mirrors log_rerun_data).
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
height, width = arr.shape[0], arr.shape[1]
n_channels = 1 if arr.ndim == 2 else arr.shape[2]
if n_channels == 1 and arr.dtype != np.uint8:
# Depth map: infer the encoding from the dtype.
encoding, target_dtype, value_max = (
("32FC1", np.float32, 1.0)
if np.issubdtype(arr.dtype, np.floating)
else ("16UC1", np.uint16, 10000.0)
)
if not raw_depth_values:
# Rescale onto the encoding's display max with respect to the given depth_range.
lo, hi = depth_range if depth_range is not None else (float(arr.min()), float(arr.max()))
arr = arr.clip(lo, hi).astype(np.float32)
arr = (arr - lo) / ((hi - lo) if hi > lo else 1.0) * value_max
arr = np.ascontiguousarray(arr, dtype=target_dtype)
else:
if n_channels == 3 and np.issubdtype(arr.dtype, np.floating):
arr = (arr * 255.0).clip(0, 255)
arr = np.ascontiguousarray(arr, dtype=np.uint8)
if compress_images and n_channels == 3:
buf_src = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
_, buf = cv2.imencode(".jpg", buf_src)
channel = channels.get(topic)
if channel is None:
channel = channels[topic] = CompressedImageChannel(topic=topic)
channel.log(
CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg"),
**log_kwargs,
)
return
encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(n_channels)
if encoding is None:
logging.warning(
"Foxglove: skipping image on topic '%s' with unsupported shape %s (%d channels); "
"expected 1 (mono8/16UC1/32FC1), 3 (rgb8), or 4 (rgba8) channels.",
topic,
tuple(arr.shape),
n_channels,
)
return
channel = channels.get(topic)
if channel is None:
channel = channels[topic] = RawImageChannel(topic=topic)
channel.log(
RawImage(
timestamp=timestamp,
frame_id=frame_id,
width=width,
height=height,
encoding=encoding,
step=width * n_channels * arr.itemsize,
data=arr.tobytes(),
),
**log_kwargs,
)
def log_foxglove_data(
observation: RobotObservation | None = None,
action: RobotAction | None = None,
compress_images: bool = False,
) -> None:
"""
Logs observation and action data to a Foxglove WebSocket server for real-time visualization.
Mirrors ``log_rerun_data`` but emits Foxglove messages over the server started by
:func:`init_foxglove`. Data is mapped as follows:
- Scalars (and elements of 1D arrays) are accumulated per source and logged on the
``/observation/state`` and ``/action/state`` topics as typed JSON messages using the static
``lerobot.Scalars`` schema: a ``scalars`` array of ``{label, value}`` objects (see
:data:`_SCALARS_SCHEMA`). The ``label`` field lets Foxglove name each series automatically, so
``/observation/state.scalars[:].value`` plots every feature at once.
- 3D NumPy arrays that resemble images are transposed from CHW to HWC when needed and logged on a
per-source topic (e.g. ``/observation/images/front``) as a ``RawImage`` (or a JPEG
``CompressedImage`` when ``compress_images`` is True).
Args:
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
compress_images: Whether to JPEG-compress images before logging to save bandwidth in exchange
for CPU and quality.
"""
require_package("foxglove-sdk", extra="viz", import_name="foxglove")
if getattr(log_foxglove_data, "server", None) is None:
raise RuntimeError("init_foxglove() must be called before log_foxglove_data().")
now = time.time_ns()
if observation:
obs_scalars: dict[str, float] = {}
for k, v in observation.items():
if v is None:
continue
key = k[len(OBS_PREFIX) :] if str(k).startswith(OBS_PREFIX) else str(k)
if _is_scalar(v):
obs_scalars[key] = float(v)
elif isinstance(v, np.ndarray):
if v.ndim == 1:
obs_scalars.update(_labeled_scalars(key, v))
else:
_log_foxglove_image(
_foxglove_topic(k, is_image=True),
key,
v,
compress_images=compress_images,
log_time=now,
)
_log_foxglove_scalars(_foxglove_topic(OBS_STATE), obs_scalars, log_time=now)
if action:
action_scalars: dict[str, float] = {}
for k, v in action.items():
if v is None:
continue
key = k[len(ACTION_PREFIX) :] if str(k).startswith(ACTION_PREFIX) else str(k)
if _is_scalar(v):
action_scalars[key] = float(v)
elif isinstance(v, np.ndarray):
action_scalars.update(_labeled_scalars(key, v.flatten()))
_log_foxglove_scalars(_foxglove_topic(ACTION), action_scalars, log_time=now)
# ── Dataset playback over a Foxglove WebSocket server ─────────────────────
# A LeRobotDataset is random-access on disk, so rather than fire-and-forget a forward stream we
# advertise a seekable timeline and serve frames on demand for whatever time the user scrubs/plays
# to in the Foxglove app. This relies on the SDK's PlaybackControl capability.
def _feature_dim_names(feature: dict | None) -> list[str] | None:
"""Best-effort per-dimension series labels for a 1D feature, or ``None`` to fall back to indices.
LeRobot records a feature's ``names`` inconsistently: a flat list (``["x", "y"]``), a category
mapping (``{"motors": ["motor_0", "motor_1"]}``), or a name->index mapping
(``{"delta_x": 0, "delta_y": 1}``). Each is handled, but labels are only returned when their count
matches the feature's 1D shape, so a malformed/mismatched ``names`` can't silently mislabel series.
"""
if not feature:
return None
shape = feature.get("shape")
dim = shape[0] if shape and len(shape) == 1 else None
names = feature.get("names")
labels: list[str] | None = None
if isinstance(names, dict):
values = list(names.values())
if values and all(isinstance(v, (list, tuple)) for v in values):
labels = [str(n) for group in values for n in group]
elif values and all(isinstance(v, int) and not isinstance(v, bool) for v in values):
labels = [name for name, _ in sorted(names.items(), key=lambda kv: kv[1])]
elif isinstance(names, (list, tuple)):
labels = [str(n) for n in names]
if labels is not None and dim is not None and len(labels) == dim:
return labels
return None
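To make the three `names` layouts concrete, here is a standalone copy of the resolution logic that can be run without LeRobot installed. The function name `feature_dim_names` and the sample feature dictionaries are illustrative assumptions; only the branching mirrors the helper above.

```python
from __future__ import annotations


def feature_dim_names(feature: dict | None) -> list[str] | None:
    """Resolve per-dimension labels, or None when they cannot be trusted."""
    if not feature:
        return None
    shape = feature.get("shape")
    dim = shape[0] if shape and len(shape) == 1 else None
    names = feature.get("names")
    labels = None
    if isinstance(names, dict):
        values = list(names.values())
        if values and all(isinstance(v, (list, tuple)) for v in values):
            # Category mapping: concatenate the groups in insertion order.
            labels = [str(n) for group in values for n in group]
        elif values and all(isinstance(v, int) and not isinstance(v, bool) for v in values):
            # Name -> index mapping: order the names by their index.
            labels = [name for name, _ in sorted(names.items(), key=lambda kv: kv[1])]
    elif isinstance(names, (list, tuple)):
        labels = [str(n) for n in names]
    # Only trust labels whose count matches the 1D shape.
    if labels is not None and dim is not None and len(labels) == dim:
        return labels
    return None


flat = feature_dim_names({"shape": (2,), "names": ["x", "y"]})
grouped = feature_dim_names({"shape": (2,), "names": {"motors": ["motor_0", "motor_1"]}})
indexed = feature_dim_names({"shape": (2,), "names": {"delta_y": 1, "delta_x": 0}})
mismatch = feature_dim_names({"shape": (3,), "names": ["x", "y"]})
```

The last call returns `None` because two labels cannot describe a three-dimensional feature, so the caller falls back to index-based series names.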
def _frame_to_scalars(sample: dict, key: str, labels: list[str] | None = None) -> dict[str, float]:
"""Flatten a frame's vector/scalar feature ``key`` into ``{label: value}`` entries.
``labels`` provides one name per dimension (from the dataset's feature metadata); when absent or
the wrong length, dimensions fall back to ``{name}_{i}`` (the short feature name), matching the
live stream so series names agree. A scalar feature becomes a single entry. Missing or ``None``
features yield an empty mapping.
"""
v = sample.get(key)
if v is None:
return {}
arr = v.numpy() if hasattr(v, "numpy") else np.asarray(v)
if key.startswith(OBS_PREFIX):
name = key[len(OBS_PREFIX) :]
elif key.startswith(ACTION_PREFIX):
name = key[len(ACTION_PREFIX) :]
else:
name = key
if arr.ndim == 0:
return {name: float(arr)}
return _labeled_scalars(name, arr.flatten(), labels)
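The body of `_labeled_scalars` is not shown in this part of the diff, so the following is only a sketch of the fallback behavior the docstring describes: use the provided labels when their count matches, otherwise name each dimension `{name}_{i}`. The helper names and the prefix strings `"observation."` and `"action."` are assumptions made for illustration.

```python
from __future__ import annotations

# Assumed values of OBS_PREFIX and ACTION_PREFIX; not confirmed by the diff.
PREFIXES = ("observation.", "action.")


def short_name(key: str) -> str:
    """Strip a known namespace prefix from a feature key."""
    for prefix in PREFIXES:
        if key.startswith(prefix):
            return key[len(prefix):]
    return key


def labeled_scalars(name: str, values: list[float], labels: list[str] | None = None) -> dict[str, float]:
    """Map each dimension to a label, falling back to '{name}_{i}'."""
    if labels is None or len(labels) != len(values):
        labels = [f"{name}_{i}" for i in range(len(values))]
    return {label: float(v) for label, v in zip(labels, values)}


named = labeled_scalars(short_name("observation.state"), [0.1, 0.2], ["shoulder", "elbow"])
fallback = labeled_scalars(short_name("observation.state"), [0.1, 0.2])
```

Keeping the fallback format identical between the live stream and dataset playback is what lets a saved Foxglove layout plot the same series in both modes.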
def serve_foxglove_dataset_playback(
dataset,
episode_index: int,
*,
host: str = "127.0.0.1",
port: int = 8765,
compress_images: bool = False,
autoplay: bool = True,
) -> None:
"""Serve a single dataset episode to Foxglove as a seekable, scrubbable timeline.
Starts a Foxglove WebSocket server advertising the ``PlaybackControl`` capability over the
episode's time range. The Foxglove app drives play/pause/seek/speed; a background thread and a
``ServerListener`` read frames from the on-disk ``dataset`` on demand and log them stamped at
their dataset timestamps, so the user can scrub anywhere in the episode. Blocks until interrupted.
Args:
dataset: A ``LeRobotDataset`` loaded for the single episode to visualize.
episode_index: Index of the episode being visualized (used only for the session name).
host: Host interface to bind the WebSocket server to.
port: Port to bind the WebSocket server to.
compress_images: Whether to JPEG-compress camera frames before logging.
autoplay: If True, start playing automatically as soon as a client connects, instead of
waiting for the user to press play in the Foxglove app.
"""
require_package("foxglove-sdk", extra="viz", import_name="foxglove")
import bisect
import threading
import foxglove
from foxglove.websocket import (
Capability,
PlaybackCommand,
PlaybackControlRequest,
PlaybackState,
PlaybackStatus,
ServerListener,
)
# Per-frame timestamps in nanoseconds (read straight from the table, no video decode).
times_ns = [int(round(float(t) * 1e9)) for t in dataset.hf_dataset["timestamp"]]
n_frames = len(times_ns)
if n_frames == 0:
raise ValueError("Cannot visualize an empty episode.")
first_ns, last_ns = times_ns[0], times_ns[-1]
camera_keys = list(dataset.meta.camera_keys)
# Dataset-wide q01/q99 depth bounds (fallback min/max) used to normalize depth to [0, 1].
depth_ranges: dict[str, tuple[float, float]] = {}
for key in dataset.meta.depth_keys:
stats = (dataset.meta.stats or {}).get(key)
if not stats:
continue
lo = stats["q01"] if "q01" in stats else stats["min"]
hi = stats["q99"] if "q99" in stats else stats["max"]
depth_ranges[key] = (float(np.asarray(lo).item()), float(np.asarray(hi).item()))
# Per-dimension series labels from the dataset metadata (e.g. joint names), computed once.
scalar_labels = {
OBS_STATE: _feature_dim_names(dataset.meta.features.get(OBS_STATE)),
ACTION: _feature_dim_names(dataset.meta.features.get(ACTION)),
}
# Local channel cache so the playback server is self-contained and doesn't touch the live-stream cache.
channels: dict = {}
def emit_frame(i: int) -> None:
"""Log every channel for frame ``i`` stamped at its dataset timestamp."""
sample = dataset[i]
log_time = times_ns[i]
for key in camera_keys:
arr = sample.get(key)
if arr is None:
continue
arr = arr.numpy() if hasattr(arr, "numpy") else np.asarray(arr)
_log_foxglove_image(
_foxglove_topic(key, is_image=True),
key,
arr,
compress_images=compress_images,
channels=channels,
log_time=log_time,
depth_range=depth_ranges.get(key),
raw_depth_values=True,
)
_log_foxglove_scalars(
_foxglove_topic(OBS_STATE),
_frame_to_scalars(sample, OBS_STATE, scalar_labels[OBS_STATE]),
channels=channels,
log_time=log_time,
)
_log_foxglove_scalars(
_foxglove_topic(ACTION),
_frame_to_scalars(sample, ACTION, scalar_labels[ACTION]),
channels=channels,
log_time=log_time,
)
episode_scalars = {}
for feat, label in (
(DONE, "done"),
(TRUNCATED, "truncated"),
(REWARD, "reward"),
(SUCCESS, "success"),
):
v = sample.get(feat)
if v is not None:
episode_scalars[label] = float(v)
_log_foxglove_scalars("/episode/state", episode_scalars, channels=channels, log_time=log_time)
lock = threading.Lock()
stop_event = threading.Event()
# Shared playback state, guarded by ``lock``. ``seek_idx`` is a one-shot request set by the
# listener and serviced by the playback loop, which is the *only* thread that emits frames (so
# concurrent random access into the on-disk dataset / video decoder never overlaps).
state = {
"status": PlaybackStatus.Paused,
"cursor": first_ns,
"speed": 1.0,
"last_idx": -1,
"seek_idx": None,
}
def index_at(t_ns: int) -> int:
return max(0, min(n_frames - 1, bisect.bisect_right(times_ns, t_ns) - 1))
# One-shot latch so autoplay fires only on the first client subscription.
autoplay_started = threading.Event()
class _PlaybackListener(ServerListener):
def on_subscribe(self, client, channel):
# Start playing automatically once a client actually connects (subscribes). Using the
# subscribe hook, rather than starting in Playing up front, means the timeline doesn't
# advance before anyone is watching. Fires once; the user can still pause/seek after.
if not autoplay:
return
with lock:
if autoplay_started.is_set() or state["status"] != PlaybackStatus.Paused:
return
autoplay_started.set()
state["status"] = PlaybackStatus.Playing
cursor, speed = state["cursor"], state["speed"]
server.broadcast_playback_state(PlaybackState(PlaybackStatus.Playing, cursor, speed, False, ""))
def on_playback_control_request(self, req: PlaybackControlRequest):
# Only mutate state here; the playback loop performs all frame emission.
with lock:
did_seek = False
if req.seek_time is not None:
cursor = max(first_ns, min(last_ns, req.seek_time))
state["cursor"] = cursor
state["last_idx"] = state["seek_idx"] = index_at(cursor)
did_seek = True
if req.playback_speed and req.playback_speed > 0:
state["speed"] = req.playback_speed
if req.playback_command == PlaybackCommand.Play:
# Restarting from the end replays from the beginning.
if state["cursor"] >= last_ns:
state["cursor"] = first_ns
state["last_idx"] = state["seek_idx"] = 0
did_seek = True
state["status"] = PlaybackStatus.Playing
elif req.playback_command == PlaybackCommand.Pause:
state["status"] = PlaybackStatus.Paused
status, cursor, speed = state["status"], state["cursor"], state["speed"]
request_id = req.request_id or ""
return PlaybackState(status, cursor, speed, did_seek, request_id)
server = foxglove.start_server(
name=f"{dataset.repo_id}/episode_{episode_index}",
host=host,
port=port,
capabilities=[Capability.PlaybackControl, Capability.Time],
server_listener=_PlaybackListener(),
playback_time_range=(first_ns, last_ns),
)
def playback_loop() -> None:
# Cap how far the cursor may advance in a single tick. A slow frame decode (or any stall)
# would otherwise make ``dt`` huge and produce one enormous catch-up batch; clamping it makes
# playback trail wall-clock under a slow decoder while each tick emits a bounded frame range.
max_tick_dt_s = 0.25
prev = time.monotonic()
while not stop_event.is_set():
time.sleep(1.0 / 60.0)
ended = False
speed = 1.0
with lock:
now = time.monotonic()
dt = min(now - prev, max_tick_dt_s)
prev = now
# A queued seek is always serviced, even while paused, so scrubbing updates the view.
work = []
seek_idx = state["seek_idx"]
if seek_idx is not None:
state["seek_idx"] = None
work.append(seek_idx)
if state["status"] == PlaybackStatus.Playing:
cursor = state["cursor"] + int(dt * 1e9 * state["speed"])
start_idx = state["last_idx"] + 1
if cursor >= last_ns:
cursor, target, ended = last_ns, n_frames - 1, True
else:
target = index_at(cursor)
state["cursor"] = cursor
work.extend(range(start_idx, target + 1))
# cursor only grows while playing (seeks reset last_idx in the listener), so
# target >= last_idx here; a plain assignment is correct and clearer than max().
state["last_idx"] = target
if ended:
state["status"] = PlaybackStatus.Ended
if not work:
continue
cursor, speed = state["cursor"], state["speed"]
# Emit outside the lock; this is the only thread that calls emit_frame. Re-check
# stop_event between frames so shutdown stays responsive even mid-batch.
for i in work:
if stop_event.is_set():
break
emit_frame(i)
server.broadcast_time(cursor)
if ended:
server.broadcast_playback_state(PlaybackState(PlaybackStatus.Ended, cursor, speed, False, ""))
# Emit the first frame so channels are advertised (done before the loop starts, so emission stays
# single-threaded). Late-connecting clients re-receive frames once they seek/play.
emit_frame(0)
with lock:
state["last_idx"] = 0
server.broadcast_time(first_ns)
server.broadcast_playback_state(PlaybackState(PlaybackStatus.Paused, first_ns, 1.0, True, ""))
thread = threading.Thread(target=playback_loop, name="foxglove-playback", daemon=True)
thread.start()
print(f"Foxglove server running. Connect the Foxglove app to ws://{host}:{port}")
print("Use the playback controls in Foxglove to play/pause and scrub the episode. Ctrl-C to exit.")
try:
while not stop_event.is_set():
time.sleep(0.5)
except KeyboardInterrupt:
print("Ctrl-C received. Exiting.")
finally:
stop_event.set()
thread.join(timeout=2.0)
server.stop()
channels.clear()
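The core of the playback loop is a pure computation: clamp the elapsed time, advance the cursor, and work out which frame indices to emit. The sketch below isolates that step so it can be reasoned about without a server or threads. The `advance` function and its return tuple are hypothetical; `index_at` follows the definition above but takes the timestamp list as an argument.

```python
import bisect


def index_at(times_ns, t_ns):
    """Index of the last frame whose timestamp is <= t_ns, clamped to the episode."""
    return max(0, min(len(times_ns) - 1, bisect.bisect_right(times_ns, t_ns) - 1))


def advance(times_ns, cursor, last_idx, dt_s, speed=1.0, max_tick_dt_s=0.25):
    """One playing tick: return (cursor, last_idx, frames_to_emit, ended)."""
    # Clamp dt so a stall produces a bounded batch instead of a huge catch-up.
    dt_s = min(dt_s, max_tick_dt_s)
    cursor += int(dt_s * 1e9 * speed)
    last_ns = times_ns[-1]
    ended = cursor >= last_ns
    if ended:
        cursor, target = last_ns, len(times_ns) - 1
    else:
        target = index_at(times_ns, cursor)
    # Emit every frame between the previously emitted one and the new target.
    work = list(range(last_idx + 1, target + 1))
    return cursor, target, work, ended


# Ten frames at 10 fps, timestamps in nanoseconds.
times = [i * 100_000_000 for i in range(10)]
normal_tick = advance(times, cursor=0, last_idx=0, dt_s=0.25)
stalled_tick = advance(times, cursor=0, last_idx=0, dt_s=5.0)
final_tick = advance(times, cursor=800_000_000, last_idx=8, dt_s=0.25)
```

With the clamp in place, a five-second stall emits the same two frames as a normal 0.25 s tick, so playback falls behind wall-clock time rather than bursting through the backlog.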
+191
@@ -0,0 +1,191 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rerun visualization backend.
Live control-loop streaming to the Rerun viewer (:func:`log_rerun_data`). Callers usually select a
backend at runtime through the dispatch in :mod:`lerobot.utils.visualization_utils` rather than
importing from here directly. Requires the ``viz`` extra (``pip install 'lerobot[viz]'``).
"""
import numbers
import os
import numpy as np
from lerobot.configs import DEPTH_MILLIMETER_UNIT, infer_depth_unit
from lerobot.types import RobotAction, RobotObservation
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
from .import_utils import require_package
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
def init_rerun(
session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
) -> None:
"""
Initializes the Rerun SDK for visualizing the control loop.
Args:
session_name: Name of the Rerun session.
ip: Optional IP for connecting to a Rerun server.
port: Optional port for connecting to a Rerun server.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
log_rerun_data.blueprint = None # Reset blueprint cache for new session
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
rr.init(session_name)
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
if ip and port:
rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
else:
rr.spawn(memory_limit=memory_limit)
def shutdown_rerun() -> None:
"""Shuts down the Rerun SDK gracefully."""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
rr.rerun_shutdown()
def _build_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]):
"""Build a Rerun blueprint laying out camera images, observation and action scalars in separate views.
The resulting views are arranged in a grid.
"""
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun.blueprint as rrb
views = [rrb.Spatial2DView(origin=path, name=path) for path in sorted(image_paths)]
if observation_paths:
views.append(rrb.TimeSeriesView(name="observation", contents=sorted(observation_paths)))
if action_paths:
views.append(rrb.TimeSeriesView(name="action", contents=sorted(action_paths)))
return rrb.Blueprint(rrb.Grid(*views))
def _ensure_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]) -> None:
"""Build and send the blueprint once, from the first observation and action data."""
if getattr(log_rerun_data, "blueprint", None) is not None:
return
if not (observation_paths or action_paths or image_paths):
return
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun as rr
blueprint = _build_blueprint(observation_paths, action_paths, image_paths)
log_rerun_data.blueprint = blueprint
rr.send_blueprint(blueprint)
def log_rerun_data(
observation: RobotObservation | None = None,
action: RobotAction | None = None,
compress_images: bool = False,
) -> None:
"""
Logs observation and action data to Rerun for real-time visualization.
This function iterates through the provided observation and action dictionaries and sends their contents
to the Rerun viewer. It handles different data types appropriately:
- Scalar values (floats, ints) are logged as `rr.Scalars`.
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
- 1D NumPy arrays are logged as a single `rr.Scalars` batch under one entity path, so that every
dimension shares the same view instead of being split across one view per element.
- Multi-dimensional **action** arrays are flattened and logged as a single `rr.Scalars` batch.
Keys are automatically namespaced with "observation." or "action." if not already present.
On the first call, a blueprint is built and sent so observation and action scalars get separate
time-series views and each image gets its own spatial view.
Args:
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
compress_images: Whether to compress images before logging, saving bandwidth and memory at the cost of CPU time and image quality.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
observation_paths: set[str] = set()
action_paths: set[str] = set()
image_paths: set[str] = set()
if observation:
for k, v in observation.items():
if v is None:
continue
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
observation_paths.add(key)
elif isinstance(v, np.ndarray):
arr = v
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 1:
rr.log(key, rr.Scalars(arr.astype(float)))
observation_paths.add(key)
else:
if arr.shape[-1] == 1:
# At record time, the depth unit is inferred from the frame's dtype.
depth_unit = infer_depth_unit(arr.dtype)
img_entity = rr.DepthImage(
arr,
meter=1000.0 if depth_unit == DEPTH_MILLIMETER_UNIT else 1.0,
colormap=rr.components.Colormap.Viridis,
)
else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity, static=True)
image_paths.add(key)
if action:
for k, v in action.items():
if v is None:
continue
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
action_paths.add(key)
elif isinstance(v, np.ndarray):
# Flatten any (incl. higher-dimensional) array into a single batched Scalars
rr.log(key, rr.Scalars(v.reshape(-1).astype(float)))
action_paths.add(key)
_ensure_blueprint(observation_paths, action_paths, image_paths)
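The CHW to HWC conversion above is a shape heuristic rather than a layout guarantee, and its edge cases are easier to see on shapes alone. A minimal sketch, assuming only the condition used in `log_rerun_data`; `to_hwc_shape` is a hypothetical helper that works on shape tuples so it runs without NumPy.

```python
def to_hwc_shape(shape):
    """Return the shape an array would have after the CHW -> HWC heuristic."""
    if len(shape) == 3 and shape[0] in (1, 3, 4) and shape[-1] not in (1, 3, 4):
        c, h, w = shape
        return (h, w, c)
    # Already channel-last, ambiguous, or not a 3D array: left untouched.
    return tuple(shape)


chw_rgb = to_hwc_shape((3, 480, 640))
hwc_rgb = to_hwc_shape((480, 640, 3))
chw_depth = to_hwc_shape((1, 480, 640))
ambiguous = to_hwc_shape((3, 64, 4))
```

The ambiguous case shows the limit of the heuristic: a channel-first image that is exactly 1, 3, or 4 pixels wide is treated as already channel-last. A single trailing channel, as in `chw_depth` after conversion, is what routes an array to the `rr.DepthImage` branch.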
+44 -142
@@ -12,166 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import os
"""Backend-agnostic visualization dispatch.
import numpy as np
Selects a visualization backend at runtime via a display-mode string (e.g. a ``--display_mode`` CLI
flag) so callers never branch on the backend. The concrete implementations live in
:mod:`lerobot.utils.rerun_visualization` and :mod:`lerobot.utils.foxglove_visualization`; importing
this module does not import ``rerun`` or ``foxglove`` (each backend imports its SDK lazily behind a
``require_package`` guard).
"""
from lerobot.types import RobotAction, RobotObservation
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
from .import_utils import require_package
from .foxglove_visualization import init_foxglove, log_foxglove_data, shutdown_foxglove
from .rerun_visualization import init_rerun, log_rerun_data, shutdown_rerun
# Visualization backends selectable at runtime via a display-mode string (e.g. a --display_mode flag).
VISUALIZATION_MODES = ("rerun", "foxglove")
def init_rerun(
session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
def init_visualization(
display_mode: str,
*,
session_name: str = "lerobot_control_loop",
ip: str | None = None,
port: int | None = None,
) -> None:
"""
Initializes the Rerun SDK for visualizing the control loop.
"""Initializes the visualization backend selected by ``display_mode``.
Args:
session_name: Name of the Rerun session.
ip: Optional IP for connecting to a Rerun server.
port: Optional port for connecting to a Rerun server.
For ``"rerun"``, ``ip``/``port`` point at an optional remote Rerun server. For ``"foxglove"``,
``ip`` is the interface to bind the WebSocket server to (``127.0.0.1`` for local only, ``0.0.0.0``
for all interfaces) and ``port`` is its port.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
log_rerun_data.blueprint = None # Reset blueprint cache for new session
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
rr.init(session_name)
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
if ip and port:
rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
if display_mode == "rerun":
init_rerun(session_name=session_name, ip=ip, port=port)
elif display_mode == "foxglove":
init_foxglove(host=ip or "127.0.0.1", port=port)
else:
rr.spawn(memory_limit=memory_limit)
raise ValueError(f"Unknown display_mode '{display_mode}'. Expected one of {VISUALIZATION_MODES}.")
def shutdown_rerun() -> None:
"""Shuts down the Rerun SDK gracefully."""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
rr.rerun_shutdown()
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
def _build_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]):
"""Build a Rerun blueprint laying out camera images, observation and action scalars in separate views.
Camera images, observation and action scalars are arranged in a grid.
"""
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun.blueprint as rrb
views = [rrb.Spatial2DView(origin=path, name=path) for path in sorted(image_paths)]
if observation_paths:
views.append(rrb.TimeSeriesView(name="observation", contents=sorted(observation_paths)))
if action_paths:
views.append(rrb.TimeSeriesView(name="action", contents=sorted(action_paths)))
return rrb.Blueprint(rrb.Grid(*views))
def _ensure_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]) -> None:
"""Build and send the blueprint once, from the first observation and action data."""
if getattr(log_rerun_data, "blueprint", None) is not None:
return
if not (observation_paths or action_paths or image_paths):
return
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun as rr
blueprint = _build_blueprint(observation_paths, action_paths, image_paths)
log_rerun_data.blueprint = blueprint
rr.send_blueprint(blueprint)
def log_rerun_data(
def log_visualization_data(
display_mode: str,
observation: RobotObservation | None = None,
action: RobotAction | None = None,
compress_images: bool = False,
) -> None:
"""
Logs observation and action data to Rerun for real-time visualization.
"""Logs observation/action data to the backend selected by ``display_mode``."""
This function iterates through the provided observation and action dictionaries and sends their contents
to the Rerun viewer. It handles different data types appropriately:
- Scalar values (floats, ints) are logged as `rr.Scalars`.
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
- 1D NumPy arrays are logged as a single `rr.Scalars` batch under one entity path, so that every
dimension shares the same view instead of being split across one view per element.
- Multi-dimensional **action** arrays are flattened and logged as a single `rr.Scalars` batch.
if display_mode == "rerun":
log_rerun_data(observation=observation, action=action, compress_images=compress_images)
elif display_mode == "foxglove":
log_foxglove_data(observation=observation, action=action, compress_images=compress_images)
else:
raise ValueError(f"Unknown display_mode '{display_mode}'. Expected one of {VISUALIZATION_MODES}.")
Keys are automatically namespaced with "observation." or "action." if not already present.
On the first call, a blueprint is built and sent so observation and action scalars get separate
time-series views and each image gets its own spatial view.
def shutdown_visualization(display_mode: str) -> None:
"""Shuts down the backend selected by ``display_mode``."""
Args:
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
observation_paths: set[str] = set()
action_paths: set[str] = set()
image_paths: set[str] = set()
if observation:
for k, v in observation.items():
if v is None:
continue
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
observation_paths.add(key)
elif isinstance(v, np.ndarray):
arr = v
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 1:
rr.log(key, rr.Scalars(arr.astype(float)))
observation_paths.add(key)
else:
if arr.shape[-1] == 1:
img_entity = rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity, static=True)
image_paths.add(key)
if action:
for k, v in action.items():
if v is None:
continue
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
action_paths.add(key)
elif isinstance(v, np.ndarray):
# Flatten any (incl. higher-dimensional) array into a single batched Scalars
rr.log(key, rr.Scalars(v.reshape(-1).astype(float)))
action_paths.add(key)
_ensure_blueprint(observation_paths, action_paths, image_paths)
if display_mode == "rerun":
shutdown_rerun()
elif display_mode == "foxglove":
shutdown_foxglove()
else:
raise ValueError(f"Unknown display_mode '{display_mode}'. Expected one of {VISUALIZATION_MODES}.")
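The three dispatch functions share the same shape: match on `display_mode`, forward the arguments, and raise `ValueError` for anything else. The sketch below shows that contract with stub backends so it can run without `rerun` or `foxglove` installed. It uses a lookup table instead of the `if`/`elif` chain in the diff, which is an alternative formulation for illustration and not the module's implementation; `_stub`, `_LOGGERS`, and `calls` are hypothetical names.

```python
VISUALIZATION_MODES = ("rerun", "foxglove")

calls = []  # Records which stub backend was invoked, and with which keys.


def _stub(backend):
    """Build a stand-in for log_rerun_data / log_foxglove_data."""
    def log(observation=None, action=None, compress_images=False):
        calls.append((backend, sorted(observation or {}), sorted(action or {})))
    return log


_LOGGERS = {mode: _stub(mode) for mode in VISUALIZATION_MODES}


def log_visualization_data(display_mode, observation=None, action=None, compress_images=False):
    """Forward to the backend selected by display_mode."""
    try:
        logger = _LOGGERS[display_mode]
    except KeyError:
        raise ValueError(
            f"Unknown display_mode '{display_mode}'. Expected one of {VISUALIZATION_MODES}."
        ) from None
    logger(observation=observation, action=action, compress_images=compress_images)


log_visualization_data("foxglove", observation={"state": 1.0})
```

Either formulation keeps the backend choice in one place, so a control loop can pass a `--display_mode` value straight through without branching on it.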
+2 -1
@@ -1531,6 +1531,7 @@ def test_valid_video_codecs_constant():
assert "h264" in VALID_VIDEO_CODECS
assert "hevc" in VALID_VIDEO_CODECS
assert "libsvtav1" in VALID_VIDEO_CODECS
assert "libaom-av1" in VALID_VIDEO_CODECS
assert "auto" in VALID_VIDEO_CODECS
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
assert "h264_nvenc" in VALID_VIDEO_CODECS
@@ -1538,7 +1539,7 @@ def test_valid_video_codecs_constant():
assert "h264_qsv" in VALID_VIDEO_CODECS
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
assert "hevc_nvenc" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 10
assert len(VALID_VIDEO_CODECS) == 11
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
+80 -32
@@ -32,6 +32,7 @@ from lerobot.configs.video import (
)
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
from lerobot.utils.constants import DEFAULT_FEATURES
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
@@ -247,42 +248,89 @@ class TestFeatureFileRouting:
dataset.finalize()
# ── 5. Depth stats unit canonicalization (millimetres) ────────────────
class TestDepthStatsUnit:
"""Depth stats are always stored in millimetres, regardless of raw frame dtype."""
class TestDepthUnitMetadata:
"""The depth unit is inferred once from dtype, stored in ``info``, and drives stats + reads."""
NUM_FRAMES = 4
@pytest.mark.parametrize("use_videos", [False, True])
def test_stats_canonicalized_to_mm(self, tmp_path, features_factory, use_videos):
"""Float (metre) and integer (millimetre) depth over the same physical range
yield identical millimetre-scale stats."""
def _record(self, root, features_factory, depth_dtype, value, use_videos):
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def _record(depth_dtype, root):
features = features_factory(
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=use_videos
)
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID,
fps=DEFAULT_FPS,
features=features,
root=root,
use_videos=use_videos,
streaming_encoding=use_videos,
)
add_frames(dataset, num_frames=self.NUM_FRAMES, depth_dtype=depth_dtype)
dataset.save_episode()
dataset.finalize()
return np.asarray(dataset.meta.stats[DEPTH_KEY]["mean"]).reshape(-1)
features = features_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=use_videos)
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID,
fps=DEFAULT_FPS,
features=features,
root=root,
use_videos=use_videos,
streaming_encoding=use_videos,
)
for _ in range(self.NUM_FRAMES):
frame: dict = {"task": "test"}
for key, ft in dataset.meta.features.items():
if key in DEFAULT_FEATURES:
continue
if key in dataset.meta.depth_keys:
frame[key] = np.full(ft["shape"], value, dtype=depth_dtype)
elif key in dataset.meta.camera_keys:
frame[key] = np.random.randint(0, 256, ft["shape"], dtype=np.uint8)
else:
frame[key] = np.zeros(ft["shape"], dtype=np.float32)
dataset.add_frame(frame)
return dataset
# add_frames ramps float depth over 0.1–10 m and integer depth over 100–10000 mm
# (the same physical range), so canonicalized stats must match.
mean_m = _record(np.float32, tmp_path / "ds_m")
mean_mm = _record(np.uint16, tmp_path / "ds_mm")
@pytest.mark.parametrize("use_videos", [False, True])
@pytest.mark.parametrize(
("depth_dtype", "value", "expected_unit"),
[(np.float32, 2.0, DEPTH_METER_UNIT), (np.uint16, 2000, DEPTH_MILLIMETER_UNIT)],
)
def test_recorded_unit_inferred_persisted_and_kept_in_stats(
self, tmp_path, features_factory, use_videos, depth_dtype, value, expected_unit
):
"""Unit is inferred from the first frame's dtype, drives stats (raw, never canonicalized), and survives a reload."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Float (metre) input is scaled to millimetres, not left in the single-digit metre range.
assert mean_m.item() > 50.0
np.testing.assert_allclose(mean_m, mean_mm, rtol=0.05)
dataset = self._record(tmp_path / "ds", features_factory, depth_dtype, value, use_videos)
assert dataset.meta.features[DEPTH_KEY]["info"]["depth_unit"] == expected_unit
dataset.save_episode()
mean = float(np.asarray(dataset.meta.stats[DEPTH_KEY]["mean"]).reshape(-1)[0])
np.testing.assert_allclose(mean, value, rtol=0.05)
dataset.finalize()
reloaded = LeRobotDataset(repo_id=DUMMY_REPO_ID, root=tmp_path / "ds")
assert reloaded.meta.features[DEPTH_KEY]["info"]["depth_unit"] == expected_unit
@pytest.mark.parametrize("use_videos", [False, True])
@pytest.mark.parametrize(
("output_unit", "expected"),
[(DEPTH_MILLIMETER_UNIT, 2000.0), (DEPTH_METER_UNIT, 2.0)],
)
def test_read_honors_output_unit_for_frames_and_stats(
self, tmp_path, features_factory, use_videos, output_unit, expected
):
"""Reloading with a ``depth_output_unit`` converts metre frames (image mode) and rescales stats while preserving count."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = self._record(tmp_path / "ds", features_factory, np.float32, 2.0, use_videos=use_videos)
dataset.save_episode()
count = float(np.asarray(dataset.meta.stats[DEPTH_KEY]["count"]).reshape(-1)[0])
dataset.finalize()
read_dataset = LeRobotDataset(
repo_id=DUMMY_REPO_ID, root=tmp_path / "ds", depth_output_unit=output_unit
)
stats = read_dataset.meta.stats[DEPTH_KEY]
np.testing.assert_allclose(float(np.asarray(stats["mean"]).reshape(-1)[0]), expected, rtol=0.05)
np.testing.assert_allclose(float(np.asarray(stats["count"]).reshape(-1)[0]), count)
if not use_videos:
depth = read_dataset[0][DEPTH_KEY]
assert torch.allclose(depth, torch.full_like(depth, expected))
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
stream_dataset = StreamingLeRobotDataset(
repo_id=DUMMY_REPO_ID, root=tmp_path / "ds", depth_output_unit=output_unit
)
stream_depth = next(iter(stream_dataset))[DEPTH_KEY]
assert torch.allclose(stream_depth, torch.full_like(stream_depth, expected))
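The read-side behaviour exercised above (value statistics rescaled to the requested output unit, `count` left alone) can be illustrated with a minimal, dependency-free sketch. The unit labels `"m"` and `"mm"` and the helper name `rescale_depth_stats` are assumptions made for illustration; the actual values of `DEPTH_METER_UNIT` / `DEPTH_MILLIMETER_UNIT` and the library's own conversion code are not shown in this diff.

```python
# Hypothetical unit labels; the real constant values are not visible in the diff.
METER = "m"
MILLIMETER = "mm"

# Millimetres per one unit of each supported depth unit.
_MM_PER_UNIT = {METER: 1000.0, MILLIMETER: 1.0}


def rescale_depth_stats(stats, stored_unit, output_unit):
    """Return a copy of ``stats`` expressed in ``output_unit``.

    Value-like entries (mean, std, min, max, ...) scale linearly with the unit;
    ``count`` is a number of samples and is left untouched.
    """
    factor = _MM_PER_UNIT[stored_unit] / _MM_PER_UNIT[output_unit]
    return {
        key: (value if key == "count" else value * factor)
        for key, value in stats.items()
    }
```

Under these assumptions, a dataset recorded at a constant 2.0 m would report a mean of about 2000 when read back in millimetres, while its `count` stays the same, which is the property the parametrized test checks.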
+3 -1
@@ -345,7 +345,9 @@ class TestExtraOptions:
opts = cfg.get_codec_options()
assert opts["qp"] == 20
assert isinstance(opts["qp"], int)
assert cfg.get_codec_options(as_strings=True)["qp"] == "20"
str_opts = cfg.get_codec_options(as_strings=True)
assert str_opts["qp"] == "20"
assert all(isinstance(v, str) for v in str_opts.values())
@require_libsvtav1
def test_structured_fields_win_on_collision(self):
+15 -12
@@ -26,6 +26,7 @@ import pytest
import torch
from datasets import Dataset
from lerobot.configs.video import infer_depth_unit
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.io_utils import flatten_dict, hf_transform_to_torch
@@ -49,18 +50,16 @@ from tests.fixtures.constants import (
)
def add_frames(dataset: LeRobotDataset, num_frames: int, depth_dtype: np.dtype = np.uint16) -> None:
def add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
"""Append ``num_frames`` synthetic frames to ``dataset``.
Generates per-feature payloads from ``dataset.meta``: depth ramps (``depth_dtype``,
default ``uint16`` millimetres; pass ``np.float32`` for metres) for keys in
``dataset.meta.depth_keys``, uint8 random noise for video/image keys, and float32
zeros for everything else. ``DEFAULT_FEATURES`` (timestamp, frame_index, ...) are
auto-populated by ``add_frame`` and skipped here.
Generates per-feature payloads from ``dataset.meta``: uint16 depth ramps for
keys in ``dataset.meta.depth_keys``, uint8 random noise for video/image keys,
and float32 zeros for everything else. ``DEFAULT_FEATURES`` (timestamp,
frame_index, ...) are auto-populated by ``add_frame`` and skipped here.
"""
video_keys = dataset.meta.video_keys
depth_keys = dataset.meta.depth_keys
depth_is_float = np.issubdtype(depth_dtype, np.floating)
# Smooth gradient base reused per (H, W) to keep depth frames cheap to
# encode (HEVC Main 12 hates white noise).
_depth_base_cache: dict[tuple[int, int], np.ndarray] = {}
@@ -72,14 +71,11 @@ def add_frames(dataset: LeRobotDataset, num_frames: int, depth_dtype: np.dtype =
shape = ft["shape"]
if key in depth_keys:
h, w, _ = shape
# Float depth is expressed in metres, integer depth in millimetres.
lo, hi = (0.1, 10.0) if depth_is_float else (100.0, 10_000.0)
base = _depth_base_cache.setdefault(
(h, w),
np.linspace(lo, hi, h * w, dtype=np.float32).reshape(h, w, 1),
np.linspace(100.0, 10_000.0, h * w, dtype=np.float32).reshape(h, w, 1),
)
step = (0.05 if depth_is_float else 50.0) * i
frame[key] = (base + step).clip(0, 65535).astype(depth_dtype)
frame[key] = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
elif key in video_keys:
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
else:
@@ -540,6 +536,13 @@ def lerobot_dataset_factory(
chunks_size=chunks_size,
**info_kwargs,
)
# This synthetic path skips add_frame, so record the depth unit the writer would
# have stored (dummy depth is uint16) to keep ``depth_unit`` present in info.json.
# Reassign a fresh info dict to avoid mutating the shared feature constants.
for ft in info.features.values():
ft_info = ft.get("info")
if ft_info is not None and ft_info.get("is_depth_map") and "depth_unit" not in ft_info:
ft["info"] = {**ft_info, "depth_unit": infer_depth_unit(np.uint16)}
if stats is None:
stats = stats_factory(features=info.features)
if tasks is None:
+40
@@ -0,0 +1,40 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import pytest
import torch
from lerobot.utils.device_utils import get_safe_autocast_context
@pytest.mark.parametrize(
("device", "enabled", "expect_autocast"),
[
("cpu", True, True), # AMP-capable device -> real autocast
(torch.device("cpu"), True, True), # accepts torch.device
("cpu", False, False), # explicitly disabled -> no-op
("mps", True, False), # AMP unsupported on mps -> no-op
("privateuseone", True, False), # unknown device -> safe no-op
],
)
def test_get_safe_autocast_context(device, enabled, expect_autocast):
ctx = get_safe_autocast_context(device, dtype=torch.bfloat16, enabled=enabled)
if expect_autocast:
assert isinstance(ctx, torch.autocast)
with ctx:
assert torch.is_autocast_enabled("cpu")
else:
assert isinstance(ctx, nullcontext)
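The gating rule this test pins down can be sketched without importing torch. This is not the body of `get_safe_autocast_context` (which the diff does not show here); it is a hedged illustration of the rule stated in the commit message: enter autocast only on `cuda`, `xpu` or `cpu`, and return a no-op context otherwise. The names `safe_autocast_sketch`, `device_type_of` and the injected `autocast_factory` parameter are hypothetical; in real use the factory would be `torch.autocast`.

```python
from contextlib import nullcontext

# Device types the commit message lists as AMP-capable.
AMP_DEVICE_TYPES = frozenset({"cuda", "xpu", "cpu"})


def device_type_of(device):
    """Return the bare device type for "cuda:0"-style strings or torch.device-like objects."""
    kind = getattr(device, "type", None)  # torch.device exposes a .type attribute
    if kind is None:
        kind = str(device)
    return kind.split(":", 1)[0]


def safe_autocast_sketch(device, autocast_factory, dtype=None, enabled=True):
    """Build an autocast context only when it is requested and supported.

    Returns ``nullcontext()`` when ``enabled`` is False or the device type is
    not in ``AMP_DEVICE_TYPES`` (for example "mps" or an unknown backend).
    """
    kind = device_type_of(device)
    if not enabled or kind not in AMP_DEVICE_TYPES:
        return nullcontext()
    return autocast_factory(device_type=kind, dtype=dtype)
```

Passing the factory in keeps the sketch runnable in environments without torch; a call site could then request autocast unconditionally and rely on the helper to degrade to a no-op.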
+101
@@ -0,0 +1,101 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Foxglove backend's pure helpers.
These cover topic naming, series labelling and feature-name parsing. They import
``foxglove_visualization`` directly and need NO ``foxglove`` extra: the SDK is imported lazily inside
the functions that talk to the server, so the helpers below run in the base test tier.
"""
import numpy as np
from lerobot.utils import foxglove_visualization as fv
from lerobot.utils.constants import ACTION, OBS_STATE
def test_foxglove_safe_name_collapses_dots():
assert fv._foxglove_safe_name("observation.images.front") == "observation_images_front"
assert fv._foxglove_safe_name("plain") == "plain"
def test_foxglove_topic_image_strips_prefix_without_doubling_images():
# Fully-qualified camera key -> single clean segment (no doubled "images").
assert fv._foxglove_topic("observation.images.front", is_image=True) == "/observation/images/front"
# A nested camera name keeps its structure via safe-name collapsing.
assert (
fv._foxglove_topic("observation.images.wrist.left", is_image=True) == "/observation/images/wrist_left"
)
# Bare camera name (as real robots emit).
assert fv._foxglove_topic("front", is_image=True) == "/observation/images/front"
def test_foxglove_topic_scalar_sources():
assert fv._foxglove_topic(OBS_STATE) == "/observation/state"
assert fv._foxglove_topic("observation.environment_state") == "/observation/state"
assert fv._foxglove_topic(ACTION) == "/action/state"
assert fv._foxglove_topic("action.delta") == "/action/state"
def test_labeled_scalars_uses_labels_then_index_fallback():
assert fv._labeled_scalars("state", np.array([1.0, 2.0, 3.0])) == {
"state_0": 1.0,
"state_1": 2.0,
"state_2": 3.0,
}
assert fv._labeled_scalars("state", [1.0, 2.0], ["pan", "lift"]) == {"pan": 1.0, "lift": 2.0}
# Wrong-length labels fall back to index naming (never silently mislabels).
assert fv._labeled_scalars("q", [1.0, 2.0], ["only_one"]) == {"q_0": 1.0, "q_1": 2.0}
def test_frame_to_scalars_matches_live_labeling_and_handles_scalar():
frame = {OBS_STATE: np.array([1.0, 2.0])}
# No metadata -> {short_name}_{i}, identical to the live-stream fallback.
assert fv._frame_to_scalars(frame, OBS_STATE) == fv._labeled_scalars("state", np.array([1.0, 2.0]))
assert fv._frame_to_scalars(frame, OBS_STATE) == {"state_0": 1.0, "state_1": 2.0}
# Metadata labels are honored.
assert fv._frame_to_scalars(frame, OBS_STATE, ["pan", "lift"]) == {"pan": 1.0, "lift": 2.0}
# A 0-d scalar becomes a single entry named by the short feature name.
assert fv._frame_to_scalars({ACTION: np.array(5.0)}, ACTION) == {"action": 5.0}
# A missing feature yields an empty mapping.
assert fv._frame_to_scalars({}, OBS_STATE) == {}
def test_feature_dim_names_formats():
# Flat list of names.
assert fv._feature_dim_names({"shape": [2], "names": ["x", "y"]}) == ["x", "y"]
# Category mapping (dict of lists).
assert fv._feature_dim_names({"shape": [2], "names": {"motors": ["m0", "m1"]}}) == ["m0", "m1"]
# name -> index mapping (returned sorted by index).
assert fv._feature_dim_names({"shape": [2], "names": {"delta_x": 0, "delta_y": 1}}) == [
"delta_x",
"delta_y",
]
# Bool values must NOT be treated as an index map (bool is a subclass of int).
assert fv._feature_dim_names({"shape": [2], "names": {"a": True, "b": False}}) is None
# Mismatched length -> None (won't silently mislabel).
assert fv._feature_dim_names({"shape": [3], "names": ["x", "y"]}) is None
# Missing / absent names -> None.
assert fv._feature_dim_names(None) is None
assert fv._feature_dim_names({"shape": [2]}) is None
def test_is_scalar():
assert fv._is_scalar(1.0)
assert fv._is_scalar(np.float32(2.0))
assert fv._is_scalar(np.array(3.0)) # 0-d array
assert not fv._is_scalar(np.array([1.0, 2.0]))
assert not fv._is_scalar("x")
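The name-parsing cases in `test_feature_dim_names_formats` can be read against the following minimal sketch. It is an illustration written from the assertions above, not the implementation in `foxglove_visualization`; the function name `feature_dim_names_sketch` and the restriction to 1-D shapes are assumptions.

```python
def feature_dim_names_sketch(feature):
    """Return one label per dimension of a 1-D feature, or None if that is not possible."""
    if not feature:
        return None
    names = feature.get("names")
    shape = feature.get("shape") or []
    # Assumption: only 1-D features get per-dimension labels.
    if names is None or len(shape) != 1:
        return None

    if isinstance(names, dict):
        values = list(names.values())
        # bool is a subclass of int, so it must be excluded explicitly.
        if values and all(isinstance(v, int) and not isinstance(v, bool) for v in values):
            # name -> index mapping: order the names by their index.
            flat = [name for name, _ in sorted(names.items(), key=lambda kv: kv[1])]
        elif values and all(isinstance(v, (list, tuple)) for v in values):
            # category -> list-of-names mapping: concatenate the lists.
            flat = [name for group in values for name in group]
        else:
            return None
    else:
        flat = list(names)

    # Refuse to return labels that do not match the feature length.
    return flat if len(flat) == shape[0] else None
```

The explicit `bool` exclusion matters because `isinstance(True, int)` is true in Python, so a mapping such as `{"a": True, "b": False}` would otherwise be mistaken for an index map.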
+311
@@ -0,0 +1,311 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import sys
from types import SimpleNamespace
import numpy as np
import pytest
pytest.importorskip("rerun", reason="rerun-sdk is required (install lerobot[viz])")
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_STATE
@pytest.fixture
def mock_rerun(monkeypatch):
"""
Provide a mock `rerun` module (and `rerun.blueprint` submodule) so tests don't
depend on the real library. Also reload the module-under-test so it binds to
this mock `rr`.
"""
calls = []
blueprints = []
class DummyScalar:
def __init__(self, value):
# Scalars may be built from a single float or from a 1D array batch.
self.value = value
class DummyImage:
def __init__(self, arr):
self.arr = arr
def compress(self, *a, **k):
return self
class DummyDepthImage:
def __init__(self, arr, meter=None, colormap=None):
self.arr = arr
self.meter = meter
self.colormap = colormap
def dummy_log(key, obj=None, **kwargs):
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
if obj is None and "entity" in kwargs:
obj = kwargs.pop("entity")
calls.append((key, obj, kwargs))
def dummy_send_blueprint(blueprint, *a, **k):
blueprints.append(blueprint)
# Mock the `rerun.blueprint` submodule used to build the layout.
dummy_rrb = SimpleNamespace(
Spatial2DView=lambda origin=None, name=None: SimpleNamespace(
kind="Spatial2DView", origin=origin, name=name
),
TimeSeriesView=lambda name=None, contents=None: SimpleNamespace(
kind="TimeSeriesView", name=name, contents=contents
),
Grid=lambda *views: SimpleNamespace(kind="Grid", views=list(views)),
Blueprint=lambda root: SimpleNamespace(kind="Blueprint", root=root),
)
dummy_rr = SimpleNamespace(
__name__="rerun",
__package__="rerun",
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
Scalars=DummyScalar,
Image=DummyImage,
DepthImage=DummyDepthImage,
components=SimpleNamespace(Colormap=SimpleNamespace(Viridis="viridis")),
log=dummy_log,
send_blueprint=dummy_send_blueprint,
init=lambda *a, **k: None,
spawn=lambda *a, **k: None,
blueprint=dummy_rrb,
)
# Inject fake modules into sys.modules (both `rerun` and `rerun.blueprint`).
monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
monkeypatch.setitem(sys.modules, "rerun.blueprint", dummy_rrb)
# Now import and reload the module under test, to bind to our rerun mock
import lerobot.utils.rerun_visualization as rv
importlib.reload(rv)
# Expose the reloaded module, the call recorder and the captured blueprints
yield rv, calls, blueprints
def _keys(calls):
"""Helper to extract just the keys logged to rr.log"""
return [k for (k, _obj, _kw) in calls]
def _obj_for(calls, key):
"""Find the first object logged under a given key."""
for k, obj, _kw in calls:
if k == key:
return obj
raise KeyError(f"Key {key} not found in calls: {calls}")
def _kwargs_for(calls, key):
for k, _obj, kw in calls:
if k == key:
return kw
raise KeyError(f"Key {key} not found in calls: {calls}")
def _views_by_kind(blueprint, kind):
"""Return the views of a given kind from the (single) blueprint's grid."""
return [v for v in blueprint.root.views if v.kind == kind]
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
rv, calls, blueprints = mock_rerun
# Build EnvTransition dict
obs = {
f"{OBS_STATE}.temperature": np.float32(25.0),
# CHW image should be converted to HWC for rr.Image
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
}
act = {
"action.throttle": 0.7,
# 1D array should be logged as a single Scalars batch under one entity path
"action.vector": np.array([1.0, 2.0], dtype=np.float32),
}
transition = {
TransitionKey.OBSERVATION: obs,
TransitionKey.ACTION: act,
}
# Extract observation and action data from transition like in the real call sites
obs_data = transition.get(TransitionKey.OBSERVATION, {})
action_data = transition.get(TransitionKey.ACTION, {})
rv.log_rerun_data(observation=obs_data, action=action_data)
# We expect:
# - observation.state.temperature -> Scalars
# - observation.camera -> Image (HWC) with static=True
# - action.throttle -> Scalars
# - action.vector -> single Scalars batch (no per-element suffix)
expected_keys = {
f"{OBS_STATE}.temperature",
"observation.camera",
"action.throttle",
"action.vector",
}
assert set(_keys(calls)) == expected_keys
# Check scalar types and values
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
assert type(temp_obj).__name__ == "DummyScalar"
assert float(temp_obj.value) == pytest.approx(25.0)
throttle_obj = _obj_for(calls, "action.throttle")
assert type(throttle_obj).__name__ == "DummyScalar"
assert float(throttle_obj.value) == pytest.approx(0.7)
# 1D vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vector")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [1.0, 2.0])
# Check image handling: CHW -> HWC
img_obj = _obj_for(calls, "observation.camera")
assert type(img_obj).__name__ == "DummyImage"
assert img_obj.arr.shape == (10, 20, 3) # transposed
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
# A blueprint should have been built and sent exactly once, and cached on the function.
assert len(blueprints) == 1
assert rv.log_rerun_data.blueprint is blueprints[0]
bp = blueprints[0]
# One spatial view per image path
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.camera"}
# One time-series view each for observation and action scalars
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert set(ts_views) == {"observation", "action"}
assert ts_views["observation"].contents == [f"{OBS_STATE}.temperature"]
assert ts_views["action"].contents == ["action.throttle", "action.vector"]
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
rv, calls, blueprints = mock_rerun
# First dict without prefixes treated as observation
# Second dict without prefixes treated as action
obs_plain = {
"temp": 1.5,
# Already HWC image => should stay as-is
"img": np.zeros((5, 6, 3), dtype=np.uint8),
"none": None, # should be skipped
}
act_plain = {
"throttle": 0.3,
"vec": np.array([9, 8, 7], dtype=np.float32),
}
# Extract observation and action data from list like the old function logic did
# First dict was treated as observation, second as action
rv.log_rerun_data(observation=obs_plain, action=act_plain)
# Expected keys with auto-prefixes. The 1D vector is a single batched Scalars.
expected = {
"observation.temp",
"observation.img",
"action.throttle",
"action.vec",
}
logged = set(_keys(calls))
assert logged == expected
# Scalars
t = _obj_for(calls, "observation.temp")
assert type(t).__name__ == "DummyScalar"
assert float(t.value) == pytest.approx(1.5)
throttle = _obj_for(calls, "action.throttle")
assert type(throttle).__name__ == "DummyScalar"
assert float(throttle.value) == pytest.approx(0.3)
# Image stays HWC
img = _obj_for(calls, "observation.img")
assert type(img).__name__ == "DummyImage"
assert img.arr.shape == (5, 6, 3)
assert _kwargs_for(calls, "observation.img").get("static", False) is True
# Vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vec")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [9, 8, 7])
# Blueprint sent once with the expected view layout
assert len(blueprints) == 1
bp = blueprints[0]
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.img"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.throttle", "action.vec"]
def test_log_rerun_data_kwargs_only(mock_rerun):
rv, calls, blueprints = mock_rerun
rv.log_rerun_data(
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
action={"action.a": 1.0},
)
keys = set(_keys(calls))
assert "observation.temp" in keys
assert "observation.gray" in keys
assert "action.a" in keys
temp = _obj_for(calls, "observation.temp")
assert type(temp).__name__ == "DummyScalar"
assert float(temp.value) == pytest.approx(10.0)
img = _obj_for(calls, "observation.gray")
assert type(img).__name__ == "DummyDepthImage" # single-channel -> DepthImage
assert img.arr.shape == (8, 8, 1) # remains HWC
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
a = _obj_for(calls, "action.a")
assert type(a).__name__ == "DummyScalar"
assert float(a.value) == pytest.approx(1.0)
# Blueprint sent once, with a spatial view for the image and time-series views for scalars
assert len(blueprints) == 1
bp = blueprints[0]
assert {v.origin for v in _views_by_kind(bp, "Spatial2DView")} == {"observation.gray"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.a"]
def test_log_rerun_data_blueprint_sent_only_once(mock_rerun):
"""The blueprint is built from the first call and not resent on subsequent calls."""
rv, calls, blueprints = mock_rerun
rv.log_rerun_data(observation={"temp": 1.0}, action={"a": 2.0})
assert len(blueprints) == 1
first_blueprint = rv.log_rerun_data.blueprint
rv.log_rerun_data(observation={"temp": 3.0}, action={"a": 4.0})
# Still only one blueprint, and the cached one is unchanged.
assert len(blueprints) == 1
assert rv.log_rerun_data.blueprint is first_blueprint
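The "sent only once, cached on the function" behaviour checked by the last test can be sketched as follows. This is a simplified stand-in rather than the code of `log_rerun_data`: the name `log_once_sketch`, the `send_blueprint` callback parameter and the dictionary used as a blueprint are all hypothetical.

```python
def log_once_sketch(entity_paths, send_blueprint):
    """Build and send a layout on the first call only; later calls reuse the cached one."""
    if log_once_sketch.blueprint is None:
        # The layout is derived from the entity paths seen on the first call.
        log_once_sketch.blueprint = {"views": sorted(entity_paths)}
        send_blueprint(log_once_sketch.blueprint)
    return log_once_sketch.blueprint


# The cache lives on the function object itself, mirroring the
# `log_rerun_data.blueprint` attribute the tests inspect.
log_once_sketch.blueprint = None
```

One consequence of this design, which the sketch shares, is that entity paths appearing only in later calls are not added to the layout unless the cached attribute is reset.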
+13 -287
@@ -14,297 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import sys
from types import SimpleNamespace
"""Tests for the backend-agnostic visualization dispatch.
These exercise the display-mode routing/validation only; they need neither ``rerun`` nor
``foxglove`` installed since the unknown-mode branch raises before touching any backend. Backend
behavior is covered in ``test_rerun_visualization.py`` and ``test_foxglove_visualization.py``.
"""
import numpy as np
import pytest
pytest.importorskip("rerun", reason="rerun-sdk is required (install lerobot[viz])")
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_STATE
from lerobot.utils import visualization_utils as vu
@pytest.fixture
def mock_rerun(monkeypatch):
"""
Provide a mock `rerun` module (and `rerun.blueprint` submodule) so tests don't
depend on the real library. Also reload the module-under-test so it binds to
this mock `rr`.
"""
calls = []
blueprints = []
class DummyScalar:
def __init__(self, value):
# Scalars may be built from a single float or from a 1D array batch.
self.value = value
class DummyImage:
def __init__(self, arr):
self.arr = arr
def compress(self, *a, **k):
return self
class DummyDepthImage:
def __init__(self, arr, colormap=None):
self.arr = arr
self.colormap = colormap
def dummy_log(key, obj=None, **kwargs):
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
if obj is None and "entity" in kwargs:
obj = kwargs.pop("entity")
calls.append((key, obj, kwargs))
def dummy_send_blueprint(blueprint, *a, **k):
blueprints.append(blueprint)
# Mock the `rerun.blueprint` submodule used to build the layout.
dummy_rrb = SimpleNamespace(
Spatial2DView=lambda origin=None, name=None: SimpleNamespace(
kind="Spatial2DView", origin=origin, name=name
),
TimeSeriesView=lambda name=None, contents=None: SimpleNamespace(
kind="TimeSeriesView", name=name, contents=contents
),
Grid=lambda *views: SimpleNamespace(kind="Grid", views=list(views)),
Blueprint=lambda root: SimpleNamespace(kind="Blueprint", root=root),
)
dummy_rr = SimpleNamespace(
__name__="rerun",
__package__="rerun",
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
Scalars=DummyScalar,
Image=DummyImage,
DepthImage=DummyDepthImage,
components=SimpleNamespace(Colormap=SimpleNamespace(Viridis="viridis")),
log=dummy_log,
send_blueprint=dummy_send_blueprint,
init=lambda *a, **k: None,
spawn=lambda *a, **k: None,
blueprint=dummy_rrb,
)
# Inject fake modules into sys.modules (both `rerun` and `rerun.blueprint`).
monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
monkeypatch.setitem(sys.modules, "rerun.blueprint", dummy_rrb)
# Now import and reload the module under test, to bind to our rerun mock
import lerobot.utils.visualization_utils as vu
importlib.reload(vu)
# Expose the reloaded module, the call recorder and the captured blueprints
yield vu, calls, blueprints
def test_visualization_modes():
assert vu.VISUALIZATION_MODES == ("rerun", "foxglove")
def _keys(calls):
"""Helper to extract just the keys logged to rr.log"""
return [k for (k, _obj, _kw) in calls]
def _obj_for(calls, key):
"""Find the first object logged under a given key."""
for k, obj, _kw in calls:
if k == key:
return obj
raise KeyError(f"Key {key} not found in calls: {calls}")
def _kwargs_for(calls, key):
for k, _obj, kw in calls:
if k == key:
return kw
raise KeyError(f"Key {key} not found in calls: {calls}")
def _views_by_kind(blueprint, kind):
"""Return the views of a given kind from the (single) blueprint's grid."""
return [v for v in blueprint.root.views if v.kind == kind]
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
vu, calls, blueprints = mock_rerun
# Build EnvTransition dict
obs = {
f"{OBS_STATE}.temperature": np.float32(25.0),
# CHW image should be converted to HWC for rr.Image
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
}
act = {
"action.throttle": 0.7,
# 1D array should be logged as a single Scalars batch under one entity path
"action.vector": np.array([1.0, 2.0], dtype=np.float32),
}
transition = {
TransitionKey.OBSERVATION: obs,
TransitionKey.ACTION: act,
}
# Extract observation and action data from transition like in the real call sites
obs_data = transition.get(TransitionKey.OBSERVATION, {})
action_data = transition.get(TransitionKey.ACTION, {})
vu.log_rerun_data(observation=obs_data, action=action_data)
# We expect:
# - observation.state.temperature -> Scalars
# - observation.camera -> Image (HWC) with static=True
# - action.throttle -> Scalars
# - action.vector -> single Scalars batch (no per-element suffix)
expected_keys = {
f"{OBS_STATE}.temperature",
"observation.camera",
"action.throttle",
"action.vector",
}
assert set(_keys(calls)) == expected_keys
# Check scalar types and values
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
assert type(temp_obj).__name__ == "DummyScalar"
assert float(temp_obj.value) == pytest.approx(25.0)
throttle_obj = _obj_for(calls, "action.throttle")
assert type(throttle_obj).__name__ == "DummyScalar"
assert float(throttle_obj.value) == pytest.approx(0.7)
# 1D vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vector")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [1.0, 2.0])
# Check image handling: CHW -> HWC
img_obj = _obj_for(calls, "observation.camera")
assert type(img_obj).__name__ == "DummyImage"
assert img_obj.arr.shape == (10, 20, 3) # transposed
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
# A blueprint should have been built and sent exactly once, and cached on the function.
assert len(blueprints) == 1
assert vu.log_rerun_data.blueprint is blueprints[0]
bp = blueprints[0]
# One spatial view per image path
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.camera"}
# One time-series view each for observation and action scalars
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert set(ts_views) == {"observation", "action"}
assert ts_views["observation"].contents == [f"{OBS_STATE}.temperature"]
assert ts_views["action"].contents == ["action.throttle", "action.vector"]
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
vu, calls, blueprints = mock_rerun
# First dict without prefixes treated as observation
# Second dict without prefixes treated as action
obs_plain = {
"temp": 1.5,
# Already HWC image => should stay as-is
"img": np.zeros((5, 6, 3), dtype=np.uint8),
"none": None, # should be skipped
}
act_plain = {
"throttle": 0.3,
"vec": np.array([9, 8, 7], dtype=np.float32),
}
# Extract observation and action data from list like the old function logic did
# First dict was treated as observation, second as action
vu.log_rerun_data(observation=obs_plain, action=act_plain)
# Expected keys with auto-prefixes. The 1D vector is a single batched Scalars.
expected = {
"observation.temp",
"observation.img",
"action.throttle",
"action.vec",
}
logged = set(_keys(calls))
assert logged == expected
# Scalars
t = _obj_for(calls, "observation.temp")
assert type(t).__name__ == "DummyScalar"
assert float(t.value) == pytest.approx(1.5)
throttle = _obj_for(calls, "action.throttle")
assert type(throttle).__name__ == "DummyScalar"
assert float(throttle.value) == pytest.approx(0.3)
# Image stays HWC
img = _obj_for(calls, "observation.img")
assert type(img).__name__ == "DummyImage"
assert img.arr.shape == (5, 6, 3)
assert _kwargs_for(calls, "observation.img").get("static", False) is True
# Vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vec")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [9, 8, 7])
# Blueprint sent once with the expected view layout
assert len(blueprints) == 1
bp = blueprints[0]
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.img"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.throttle", "action.vec"]
def test_log_rerun_data_kwargs_only(mock_rerun):
vu, calls, blueprints = mock_rerun
vu.log_rerun_data(
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
action={"action.a": 1.0},
)
keys = set(_keys(calls))
assert "observation.temp" in keys
assert "observation.gray" in keys
assert "action.a" in keys
temp = _obj_for(calls, "observation.temp")
assert type(temp).__name__ == "DummyScalar"
assert float(temp.value) == pytest.approx(10.0)
img = _obj_for(calls, "observation.gray")
assert type(img).__name__ == "DummyDepthImage" # single-channel -> DepthImage
assert img.arr.shape == (8, 8, 1) # remains HWC
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
a = _obj_for(calls, "action.a")
assert type(a).__name__ == "DummyScalar"
assert float(a.value) == pytest.approx(1.0)
# Blueprint sent once, with a spatial view for the image and time-series views for scalars
assert len(blueprints) == 1
bp = blueprints[0]
assert {v.origin for v in _views_by_kind(bp, "Spatial2DView")} == {"observation.gray"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.a"]
def test_log_rerun_data_blueprint_sent_only_once(mock_rerun):
"""The blueprint is built from the first call and not resent on subsequent calls."""
vu, calls, blueprints = mock_rerun
vu.log_rerun_data(observation={"temp": 1.0}, action={"a": 2.0})
assert len(blueprints) == 1
first_blueprint = vu.log_rerun_data.blueprint
vu.log_rerun_data(observation={"temp": 3.0}, action={"a": 4.0})
# Still only one blueprint, and the cached one is unchanged.
assert len(blueprints) == 1
assert vu.log_rerun_data.blueprint is first_blueprint
@pytest.mark.parametrize("func", ["init_visualization", "log_visualization_data", "shutdown_visualization"])
def test_dispatch_rejects_unknown_mode(func):
with pytest.raises(ValueError, match="Unknown display_mode"):
getattr(vu, func)("bogus")
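The validation that `test_dispatch_rejects_unknown_mode` relies on can be sketched as a small guard that runs before any backend import. Only the `VISUALIZATION_MODES` tuple and the "Unknown display_mode" message prefix come from the tests above; the helper name `check_display_mode` and the rest of the message are assumptions.

```python
VISUALIZATION_MODES = ("rerun", "foxglove")


def check_display_mode(display_mode):
    """Return ``display_mode`` if it names a known backend, else raise ValueError."""
    if display_mode not in VISUALIZATION_MODES:
        raise ValueError(
            f"Unknown display_mode {display_mode!r}; expected one of {VISUALIZATION_MODES}"
        )
    return display_mode
```

Validating first and importing the backend afterwards is what lets the dispatch tests run with neither `rerun` nor `foxglove` installed.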
Generated
+25
@@ -1550,6 +1550,26 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/47/c99d5268f354002ce80f8d029cd9d7d872969da1de8b93d32de4dc56d6f4/fonttools-4.63.0-py3-none-any.whl", hash = "sha256:445af2eab030a16b9171ea8bdda7ebf7d96bda2df88ee182a464252f6e05e20d", size = 1164562, upload-time = "2026-05-14T12:04:29.092Z" },
]
[[package]]
name = "foxglove-sdk"
version = "0.25.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c1/a7/86a252782ea0d9baf1357369ad1bbf1ed644768702b0266a3fa3a05361d0/foxglove_sdk-0.25.1.tar.gz", hash = "sha256:8230f3c32ea3ab715818687377491594ec9c7e58e6b0ed8ed91aadf937ce706b", size = 547778, upload-time = "2026-06-02T03:13:18.942Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/58/15/59f02e8201b8da09ce05d8774820c29efc9149862b70ee6b3a27968e791a/foxglove_sdk-0.25.1-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:5af9f9a691eefbe6e0a47875ff2f7d0fc36607f0920e8690bbdc2dfd4fb22451", size = 17911538, upload-time = "2026-06-02T03:13:12.493Z" },
{ url = "https://files.pythonhosted.org/packages/27/ed/16d809fab24cbfdf97c15c9cdd80eabfeb447ca545ede426950d62bac848/foxglove_sdk-0.25.1-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:3e908bd87d1926a05c785779d8252db6b87eef685f284ec1cf46ee501645d08e", size = 16452309, upload-time = "2026-06-02T03:13:10.607Z" },
{ url = "https://files.pythonhosted.org/packages/d6/c3/f95874935a3436841487df1f0202de4d20eabc0adb6b79c94c531bbe7eb3/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:968e32c8668d172f6b546c8e7af658ed35a21ec165adc3bacf53a04dda159f12", size = 2355680, upload-time = "2026-06-02T02:34:01.668Z" },
{ url = "https://files.pythonhosted.org/packages/38/da/ad22d8d6e3fedde9fc0c49aa8b20394e5e0bc44ab3fba564c77a64ddc7e2/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3f75374fedafe259c40b19bc645589d9453708eab679a5b07c603035f936d29a", size = 2274075, upload-time = "2026-06-02T02:34:07.212Z" },
{ url = "https://files.pythonhosted.org/packages/a3/fa/1254adb5e72eff507695473e9c82d0e90395b61463e5353762250db30d3d/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b6d3af517a00342bf7b08a4a65b043f3eafaa197138752b6fbd704fb91043fa", size = 2282160, upload-time = "2026-06-02T02:34:08.812Z" },
{ url = "https://files.pythonhosted.org/packages/7c/e4/2b22ef06ba4058494c7aa35974d138f8f1ae4cf5273f77d69c9dc3a99b45/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:aed27c0f03a45fd6abdd566498bfee2672391602bcff32c827b8e3a6d8f67ab1", size = 22685338, upload-time = "2026-06-02T02:34:04.688Z" },
{ url = "https://files.pythonhosted.org/packages/35/7c/58324c99b80eef0b674c8d4f5c2e07c66fd1480a27a8f0d4d79371805111/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:419dd8308e3f91e2ae487b727f1bf1804642990876163b2a353db4a1b1de1425", size = 19326096, upload-time = "2026-06-02T02:34:10.939Z" },
{ url = "https://files.pythonhosted.org/packages/fe/9c/3452d92959e05fc6b1c1e5f032605d55623aeb6704357d20408f8781bc84/foxglove_sdk-0.25.1-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0fcb36e628ab3d9043e193f12ad4dbbb955fe18616aac7ef5bca82c52910f108", size = 2539020, upload-time = "2026-06-02T03:13:14.365Z" },
{ url = "https://files.pythonhosted.org/packages/b5/af/57fa58525d3acb5c5480a6f0ef86450b1a0ccae2b21248edb1376073ce55/foxglove_sdk-0.25.1-cp310-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:7909fd9f94935935dd8813702d84ffdbfebeb3866673c618ce35e8cfedd03029", size = 2550999, upload-time = "2026-06-02T03:13:15.715Z" },
{ url = "https://files.pythonhosted.org/packages/90/78/f74bb167186c965d475ff360fa6eb7441d5ac6c6239d60f542f63984f849/foxglove_sdk-0.25.1-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:69d5966213b5212b8841b4004fe582db924a74f1610d8452ad890f6931702926", size = 2560166, upload-time = "2026-06-02T03:13:17.254Z" },
{ url = "https://files.pythonhosted.org/packages/81/83/1c4c6d04fbd4784fe44fb2da021db1adf1f03a371f1e5679a383c1173235/foxglove_sdk-0.25.1-cp310-abi3-win32.whl", hash = "sha256:2a1121a5c74590ff6e61628c4a46dc57d392d290b4beeb29d6852933da56224a", size = 1618124, upload-time = "2026-06-02T03:13:20.158Z" },
{ url = "https://files.pythonhosted.org/packages/5f/4d/bdb9e252a41a951eb53908ac9cb965b7480c3ba649174f5398d4fcf0ca1d/foxglove_sdk-0.25.1-cp310-abi3-win_amd64.whl", hash = "sha256:6ed3ad0d3e72cd7875e7e293709c5ff90494fe14f1b48a336baffc313a7272cc", size = 16588452, upload-time = "2026-06-02T03:13:21.636Z" },
]
[[package]]
name = "fqdn"
version = "1.5.1"
@@ -2811,6 +2831,7 @@ all = [
{ name = "faker" },
{ name = "fastapi" },
{ name = "feetech-servo-sdk" },
{ name = "foxglove-sdk" },
{ name = "grpcio" },
{ name = "grpcio-tools" },
{ name = "gym-aloha" },
@@ -2895,6 +2916,7 @@ core-scripts = [
{ name = "av" },
{ name = "datasets" },
{ name = "deepdiff" },
{ name = "foxglove-sdk" },
{ name = "jsonlines" },
{ name = "pandas" },
{ name = "pyarrow" },
@@ -2917,6 +2939,7 @@ dataset = [
dataset-viz = [
{ name = "av" },
{ name = "datasets" },
{ name = "foxglove-sdk" },
{ name = "jsonlines" },
{ name = "pandas" },
{ name = "pyarrow" },
@@ -3187,6 +3210,7 @@ video-benchmark = [
{ name = "scikit-image" },
]
viz = [
{ name = "foxglove-sdk" },
{ name = "rerun-sdk" },
]
vla-jepa = [
@@ -3226,6 +3250,7 @@ requires-dist = [
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
{ name = "foxglove-sdk", marker = "extra == 'viz'", specifier = ">=0.25.1,<0.26.0" },
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = ">=1.73.1,<2.0.0" },
{ name = "grpcio", marker = "extra == 'reachy2'", specifier = "<=1.73.1" },
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.73.1,<2.0.0" },