Compare commits

..

1 Commits

Author SHA1 Message Date
CarolinePascal d7be868dfa feat(libaom-av1): adding support for libaom-av1 codec 2026-06-30 17:57:17 +02:00
27 changed files with 558 additions and 1640 deletions
+5 -5
View File
@@ -126,7 +126,7 @@ import time
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.utils.visualization_utils import init_visualization, log_visualization_data, shutdown_visualization
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
@@ -142,7 +142,7 @@ teleop_config = SO101LeaderConfig(
id="my_leader_arm",
)
init_visualization("rerun", session_name="teleoperation") # pass "foxglove" to stream to Foxglove instead
init_rerun(session_name="teleoperation")
robot = SO101Follower(robot_config)
teleop_device = SO101Leader(teleop_config)
@@ -158,7 +158,7 @@ while True:
observation = robot.get_observation()
action = teleop_device.get_action()
robot.send_action(action)
log_visualization_data("rerun", observation=observation, action=action)
log_rerun_data(observation=observation, action=action)
elapsed_time = time.perf_counter() - start_time
sleep_time = TIME_PER_FRAME - elapsed_time
@@ -223,7 +223,7 @@ from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_visualization
from lerobot.utils.visualization_utils import init_rerun
from lerobot.scripts.lerobot_record import record_loop
from lerobot.processor import make_default_processors
@@ -270,7 +270,7 @@ def main():
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
init_visualization("rerun", session_name="recording")
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
-2
View File
@@ -265,8 +265,6 @@ lerobot-dataset-viz \
Once executed, the tool opens `rerun.io` and displays the camera streams, robot states, and actions for the selected episode.
To use [Foxglove](https://foxglove.dev) instead of Rerun, install the extra add `--display-mode foxglove`. This starts a WebSocket server (connect the Foxglove app to `ws://127.0.0.1:8765`) that serves the episode as a seekable timeline you can play/pause and scrub.
For advanced usage—including visualizing datasets stored on a remote server—run:
```bash
-1
View File
@@ -125,7 +125,6 @@ hardware = [
]
viz = [
"rerun-sdk>=0.24.0,<0.34.0",
"foxglove-sdk>=0.25.1,<0.26.0",
]
# ── User-facing composite extras (map to CLI scripts) ─────
# lerobot-record, lerobot-replay, lerobot-calibrate, lerobot-teleoperate, etc.
+9 -1
View File
@@ -36,7 +36,9 @@ HW_VIDEO_CODECS = [
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
{"h264", "hevc", "libsvtav1", "libaom-av1", "auto", *HW_VIDEO_CODECS}
)
# Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
@@ -220,6 +222,12 @@ class VideoEncoderConfig:
if self.fast_decode:
opts["tune"] = "fastdecode"
set_if("threads", encoder_threads)
elif self.vcodec == "libaom-av1":
set_if("crf", self.crf)
set_if("preset", self.preset)
if encoder_threads is not None:
opts["threads"] = encoder_threads
opts["row-mt"] = 1
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
if self.crf is not None:
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
+29 -27
View File
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import builtins
import dataclasses
@@ -21,7 +19,7 @@ import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, TypedDict, TypeVar, Unpack
from typing import TypedDict, TypeVar, Unpack
import packaging
import safetensors
@@ -40,13 +38,10 @@ from .utils import log_model_loading_keys
T = TypeVar("T", bound="PreTrainedPolicy")
if TYPE_CHECKING:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
def _build_card_context(
cfg: TrainPipelineConfig | None,
dataset_meta: LeRobotDatasetMetadata | None,
dataset_repo_id: str | None,
input_features: dict | None,
output_features: dict | None,
) -> dict:
@@ -77,16 +72,30 @@ def _build_card_context(
"lerobot_version": __version__,
}
if dataset_meta is not None:
context["dataset"] = {
"repo_id": dataset_meta.repo_id,
"episodes": dataset_meta.total_episodes,
"frames": dataset_meta.total_frames,
"fps": dataset_meta.fps,
"tasks": [str(task) for task in dataset_meta.tasks.index],
}
context["robot_type"] = dataset_meta.robot_type
context["cameras"] = [key.split(".")[-1] for key in dataset_meta.camera_keys]
if dataset_repo_id:
dataset_cfg = getattr(cfg, "dataset", None)
try:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(
dataset_repo_id,
root=getattr(dataset_cfg, "root", None),
revision=getattr(dataset_cfg, "revision", None),
)
context["dataset"] = {
"repo_id": dataset_repo_id,
"episodes": meta.total_episodes,
"frames": meta.total_frames,
"fps": meta.fps,
"tasks": [str(task) for task in meta.tasks.index],
}
context["robot_type"] = meta.robot_type
context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys]
except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push
logging.warning(
f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be "
f"omitted from the model card. ({e})"
)
return context
@@ -295,7 +304,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
cfg: TrainPipelineConfig,
peft_model=None,
state_dict: dict[str, Tensor] | None = None,
dataset_meta: LeRobotDatasetMetadata | None = None,
):
api = HfApi()
repo_id = api.create_repo(
@@ -317,12 +325,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self.save_pretrained(saved_path, state_dict=state_dict)
card = self.generate_model_card(
cfg.dataset.repo_id,
self.config.type,
self.config.license,
self.config.tags,
cfg=cfg,
dataset_meta=dataset_meta,
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
)
card.save(str(saved_path / "README.md"))
@@ -349,7 +352,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
license: str | None,
tags: list[str] | None,
cfg: TrainPipelineConfig | None = None,
dataset_meta: LeRobotDatasetMetadata | None = None,
) -> ModelCard:
base_model_mapping = {
"smolvla": "lerobot/smolvla_base",
@@ -370,7 +372,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
)
context = _build_card_context(
cfg, dataset_meta, self.config.input_features, self.config.output_features
cfg, dataset_repo_id, self.config.input_features, self.config.output_features
)
# Used by the template to pre-fill commands and the "Fine-tuned from" line.
context["policy_repo_id"] = getattr(self.config, "repo_id", None)
@@ -387,7 +389,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self,
peft_config=None,
peft_cli_overrides: dict | None = None,
) -> PreTrainedPolicy:
) -> "PreTrainedPolicy":
"""
Wrap this policy with PEFT adapters for parameter-efficient fine-tuning.
@@ -65,13 +65,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
cameras=left_arm_cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
control_mode=config.left_arm_config.control_mode,
mit_kp=config.left_arm_config.mit_kp,
mit_kd=config.left_arm_config.mit_kd,
gripper_control_mode=config.left_arm_config.gripper_control_mode,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.left_arm_config.gripper_mit_kp,
gripper_mit_kd=config.left_arm_config.gripper_mit_kd,
joint_limits=config.left_arm_config.joint_limits,
)
@@ -86,13 +80,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
cameras=config.right_arm_config.cameras,
motor_can_ids=config.right_arm_config.motor_can_ids,
pos_vel_velocity=config.right_arm_config.pos_vel_velocity,
control_mode=config.right_arm_config.control_mode,
mit_kp=config.right_arm_config.mit_kp,
mit_kd=config.right_arm_config.mit_kd,
gripper_control_mode=config.right_arm_config.gripper_control_mode,
gripper_torque_ratio=config.right_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.right_arm_config.gripper_mit_kp,
gripper_mit_kd=config.right_arm_config.gripper_mit_kd,
joint_limits=config.right_arm_config.joint_limits,
)
@@ -65,33 +65,18 @@ class RebotB601FollowerConfig:
}
)
# Max speed (deg/s) per joint for POS_VEL arms and FORCE_POS gripper (motor order).
pos_vel_velocity: float | list[float] = field(
default_factory=lambda: [150.0, 150.0, 150.0, 150.0, 150.0, 150.0, 900.0]
)
# Target velocity for joints running in POS_VEL mode, in degrees/s. A scalar is
# applied to every joint; a list provides one value per joint (in motor order).
pos_vel_velocity: float | list[float] = field(default_factory=lambda: [150.0] * 7)
# Arm control: "mit" or "pos_vel".
control_mode: str = "mit"
# MIT kp/kd per arm joint (motor order). Unused when control_mode="pos_vel".
mit_kp: float | list[float] = field(default_factory=lambda: [45.0, 45.0, 45.0, 8.0, 9.0, 8.0, 8.0])
mit_kd: float | list[float] = field(default_factory=lambda: [12.0, 12.0, 12.0, 1.0, 1.0, 1.0, 1.0])
# Gripper control: "force_pos" or "mit".
gripper_control_mode: str = "force_pos"
# FORCE_POS only: max grip force, in [0, 1].
gripper_torque_ratio: float = 0.07
# MIT only.
gripper_mit_kp: float = 8.0
gripper_mit_kd: float = 0.3
# Torque/current ratio for the gripper's FORCE_POS mode, in range [0, 1].
gripper_torque_ratio: float = 0.1
# Soft joint limits (degrees). These are clipped against on every action.
joint_limits: dict[str, tuple[float, float]] = field(
default_factory=lambda: {
"shoulder_pan": (-150.0, 150.0),
"shoulder_lift": (-200.0, 1.0),
"shoulder_pan": (-145.0, 145.0),
"shoulder_lift": (-170.0, 1.0),
"elbow_flex": (-200.0, 1.0),
"wrist_flex": (-80.0, 90.0),
"wrist_yaw": (-90.0, 90.0),
@@ -174,25 +174,11 @@ class RebotB601Follower(Robot):
print(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
if self.config.control_mode not in ("pos_vel", "mit"):
raise ValueError(
f"Unsupported control_mode '{self.config.control_mode}'. Use 'pos_vel' or 'mit'."
)
if self.config.gripper_control_mode not in ("force_pos", "mit"):
raise ValueError(
f"Unsupported gripper_control_mode '{self.config.gripper_control_mode}'. "
"Use 'force_pos' or 'mit'."
)
use_mit = self.config.control_mode == "mit"
gripper_use_mit = self.config.gripper_control_mode == "mit"
self.bus.enable_all()
for motor_name, motor in self.motors.items():
if motor_name == GRIPPER_MOTOR:
target_mode = MotorBridgeMode.MIT if gripper_use_mit else MotorBridgeMode.FORCE_POS
elif use_mit:
target_mode = MotorBridgeMode.MIT
else:
target_mode = MotorBridgeMode.POS_VEL
target_mode = (
MotorBridgeMode.FORCE_POS if motor_name == GRIPPER_MOTOR else MotorBridgeMode.POS_VEL
)
for attempt in range(_ENSURE_MODE_RETRIES + 1):
try:
motor.ensure_mode(target_mode)
@@ -278,34 +264,22 @@ class RebotB601Follower(Robot):
goal_present_pos = {key: (g, present_pos.get(key, g)) for key, g in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
use_mit = self.config.control_mode == "mit"
for motor_name, position_deg in goal_pos.items():
motor = self.motors.get(motor_name)
if motor is None:
continue
idx = self.motor_names.index(motor_name)
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
pos_rad = math.radians(position_deg)
vel_rad = math.radians(vel_deg_s)
if motor_name == GRIPPER_MOTOR:
if self.config.gripper_control_mode == "mit":
motor.send_mit(pos_rad, 0.0, self.config.gripper_mit_kp, self.config.gripper_mit_kd, 0.0)
else:
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_force_pos(pos_rad, math.radians(vel_deg_s), self.config.gripper_torque_ratio)
elif use_mit:
kp = self.config.mit_kp[idx] if isinstance(self.config.mit_kp, list) else self.config.mit_kp
kd = self.config.mit_kd[idx] if isinstance(self.config.mit_kd, list) else self.config.mit_kd
motor.send_mit(pos_rad, 0.0, kp, kd, 0.0)
motor.send_force_pos(pos_rad, vel_rad, self.config.gripper_torque_ratio)
else:
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_pos_vel(pos_rad, math.radians(vel_deg_s))
motor.send_pos_vel(pos_rad, vel_rad)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
+3 -6
View File
@@ -226,14 +226,11 @@ class RolloutConfig:
device: str | None = None
task: str = ""
display_data: bool = False
# Visualization backend used when display_data is True: "rerun" or "foxglove".
display_mode: str = "rerun"
# For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket
# server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces).
# Display data on a remote Rerun server
display_ip: str | None = None
# For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to.
# Port of the remote Rerun server
display_port: int | None = None
# Whether to display compressed (JPEG) images instead of raw frames
# Whether to display compressed images in Rerun
display_compressed_images: bool = False
# Use vocal synthesis to read events
play_sounds: bool = True
+3 -4
View File
@@ -26,7 +26,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import log_visualization_data
from lerobot.utils.visualization_utils import log_rerun_data
from ..inference import InferenceEngine
@@ -162,12 +162,11 @@ class RolloutStrategy(abc.ABC):
action_dict: dict | None,
runtime_ctx: RuntimeContext,
) -> None:
"""Log observation/action telemetry to the visualization backend if display_data is enabled."""
"""Log observation/action telemetry to Rerun if display_data is enabled."""
cfg = runtime_ctx.cfg
if not cfg.display_data:
return
log_visualization_data(
cfg.display_mode,
log_rerun_data(
observation=obs_processed,
action=action_dict,
compress_images=cfg.display_compressed_images,
+2 -5
View File
@@ -44,7 +44,7 @@ from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import log_visualization_data
from lerobot.utils.visualization_utils import log_rerun_data
from ..configs import EpisodicStrategyConfig
from ..context import RolloutContext
@@ -171,7 +171,6 @@ class EpisodicStrategy(RolloutStrategy):
fps=fps,
control_time_s=reset_time_s,
display_data=cfg.display_data,
display_mode=cfg.display_mode,
display_compressed=display_compressed,
)
@@ -260,7 +259,6 @@ class EpisodicStrategy(RolloutStrategy):
fps: float,
control_time_s: float,
display_data: bool,
display_mode: str,
display_compressed: bool,
) -> None:
"""Reset-phase loop: teleop drives the robot if available, no recording."""
@@ -290,8 +288,7 @@ class EpisodicStrategy(RolloutStrategy):
if display_data:
obs_processed = processors.robot_observation_processor(obs)
log_visualization_data(
display_mode,
log_rerun_data(
observation=obs_processed,
action=act_teleop,
compress_images=display_compressed,
+32 -100
View File
@@ -59,18 +59,6 @@ distant$ lerobot-dataset-viz \
local$ rerun rerun+http://IP:GRPC_PORT/proxy
```
- Visualize data in Foxglove with a seekable, scrubbable timeline:
```
local$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0 \
--display-mode foxglove
# then open the Foxglove app and connect to ws://127.0.0.1:8765
```
This starts a Foxglove WebSocket server that serves the episode on demand from the on-disk dataset,
so you can play/pause and scrub anywhere in the episode using Foxglove's playback controls.
"""
import argparse
@@ -85,12 +73,9 @@ import torch.utils.data
import tqdm
from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD, SUCCESS
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.utils import init_logging
DEFAULT_FOXGLOVE_PORT = 8765
DEFAULT_RERUN_PORT = 9090
def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]:
"""Return per-dimension names for a feature from the dataset metadata.
@@ -123,12 +108,6 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
return hwc_uint8_numpy
def to_hwc_float32_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
check_chw_float32(chw_float32_torch)
hwc_float32_numpy = chw_float32_torch.permute(1, 2, 0).numpy()
return hwc_float32_numpy
def build_blueprint_from_dataset(dataset: LeRobotDataset):
"""Build a Rerun blueprint laying out camera images and time series for the given dataset.
@@ -147,43 +126,32 @@ def build_blueprint_from_dataset(dataset: LeRobotDataset):
names = get_feature_names(dataset, key)
styling = rr.SeriesLines(names=names)
views.append(rrb.TimeSeriesView(origin=origin, name=origin, overrides={origin: styling}))
for key in (DONE, REWARD, SUCCESS):
for key in (DONE, REWARD, "next.success"):
if key in dataset.features:
views.append(rrb.TimeSeriesView(origin=key, name=key))
return rrb.Blueprint(rrb.Grid(*views))
def to_hwc_uint16_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
check_chw_float32(chw_float32_torch)
hwc_uint16_numpy = chw_float32_torch.round().type(torch.uint16).permute(1, 2, 0).numpy()
return hwc_uint16_numpy
def visualize_dataset(
dataset: LeRobotDataset,
episode_index: int,
batch_size: int = 32,
num_workers: int = 0,
mode: str = "local",
web_port: int | None = None,
web_port: int = 9090,
grpc_port: int = 9876,
save: bool = False,
output_dir: Path | None = None,
display_compressed_images: bool = False,
display_mode: str = "rerun",
host: str = "127.0.0.1",
autoplay: bool = True,
**kwargs,
) -> Path | None:
if display_mode == "foxglove":
from lerobot.utils.foxglove_visualization import serve_foxglove_dataset_playback
logging.info("Starting Foxglove server")
serve_foxglove_dataset_playback(
dataset,
episode_index,
host=host,
port=web_port if web_port is not None else DEFAULT_FOXGLOVE_PORT,
compress_images=display_compressed_images,
autoplay=autoplay,
)
return None
if save:
assert output_dir is not None, (
"Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
@@ -220,20 +188,14 @@ def visualize_dataset(
if mode == "distant":
server_uri = rr.serve_grpc(grpc_port=grpc_port)
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
rr.serve_web_viewer(
open_browser=False,
web_port=web_port if web_port is not None else DEFAULT_RERUN_PORT,
connect_to=server_uri,
)
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
logging.info("Logging to Rerun")
# Use the dataset's q01/q99 depth statistics for robust depth range bounds
depth_ranges = {}
for key in dataset.meta.depth_keys:
stats = (dataset.meta.stats or {}).get(key)
if not stats:
continue
stats = dataset.meta.stats[key]
lo = stats["q01"] if "q01" in stats else stats["min"]
hi = stats["q99"] if "q99" in stats else stats["max"]
depth_ranges[key] = (float(np.asarray(lo).item()), float(np.asarray(hi).item()))
@@ -251,11 +213,11 @@ def visualize_dataset(
# display each camera image (or depth map)
for key in dataset.meta.camera_keys:
if key in dataset.meta.depth_keys:
depth = to_hwc_float32_numpy(batch[key][i])
depth = to_hwc_uint16_numpy(batch[key][i])
depth_entity = rr.DepthImage(
depth,
colormap=rr.components.Colormap.Viridis,
depth_range=depth_ranges.get(key),
depth_range=depth_ranges[key],
)
rr.log(key, entity=depth_entity)
else:
@@ -277,8 +239,8 @@ def visualize_dataset(
if REWARD in batch:
rr.log(REWARD, rr.Scalars(batch[REWARD][i].item()))
if SUCCESS in batch:
rr.log(SUCCESS, rr.Scalars(batch[SUCCESS][i].item()))
if "next.success" in batch:
rr.log("next.success", rr.Scalars(batch["next.success"][i].item()))
# save .rrd locally
if mode == "local" and save:
@@ -350,11 +312,13 @@ def main():
parser.add_argument(
"--web-port",
type=int,
default=None,
help=(
"Web/WebSocket port. For rerun `--mode distant` it is the web viewer port (default 9090); "
"for `--display-mode foxglove` it is the server bind port (default 8765)."
),
default=9090,
help="Web port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--ws-port",
type=int,
help="deprecated, please use --grpc-port instead.",
)
parser.add_argument(
"--grpc-port",
@@ -387,56 +351,24 @@ def main():
parser.add_argument(
"--display-compressed-images",
action="store_true",
help="If set, display compressed (JPEG) images instead of uncompressed ones.",
)
parser.add_argument(
"--display-mode",
type=str,
default="rerun",
choices=["rerun", "foxglove"],
help=(
"Visualization backend. 'rerun' uses the Rerun viewer (--mode/--save/--*-port apply). "
"'foxglove' starts a Foxglove WebSocket server that serves the episode as a seekable, "
"scrubbable timeline; connect the Foxglove app to ws://HOST:PORT (--host/--web-port)."
),
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help=(
"Host to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set "
"(127.0.0.1 for local only, 0.0.0.0 for all interfaces)."
),
)
parser.add_argument(
"--no-autoplay",
dest="autoplay",
action="store_false",
help=(
"For `--display-mode foxglove`: don't start playing automatically when a client "
"connects; wait for play to be pressed in the Foxglove app instead."
),
help="If set, display compressed images in Rerun instead of uncompressed ones.",
)
args = parser.parse_args()
if args.display_mode == "foxglove":
rerun_only = ("mode", "save", "output_dir", "grpc_port", "batch_size", "num_workers")
ignored = [name for name in rerun_only if getattr(args, name) != parser.get_default(name)]
if ignored:
logging.warning(
"These flags only apply to `--display-mode rerun` and are ignored with "
"`--display-mode foxglove`: %s.",
", ".join(f"--{name.replace('_', '-')}" for name in ignored),
)
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
tolerance_s = kwargs.pop("tolerance_s")
if kwargs["ws_port"] is not None:
logging.warning(
"--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
)
logging.warning("Setting grpc_port to ws_port value.")
kwargs["grpc_port"] = kwargs.pop("ws_port")
else:
kwargs.pop("ws_port") # Always remove ws_port from kwargs
init_logging()
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
+7 -28
View File
@@ -38,9 +38,6 @@ lerobot-record \\
--display_data=true
```
To stream the data to Foxglove instead of Rerun, add ``--display_mode=foxglove`` (then connect the
Foxglove app to ``ws://127.0.0.1:8765``; override the port with ``--display_port=<port>``).
Example recording with bimanual so100:
```shell
lerobot-record \\
@@ -160,11 +157,7 @@ from lerobot.utils.utils import (
init_logging,
log_say,
)
from lerobot.utils.visualization_utils import (
init_visualization,
log_visualization_data,
shutdown_visualization,
)
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@dataclass
@@ -175,14 +168,11 @@ class RecordConfig:
teleop: TeleoperatorConfig | None = None
# Display all cameras on screen
display_data: bool = False
# Visualization backend used when display_data is True: "rerun" or "foxglove".
display_mode: str = "rerun"
# For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket
# server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces).
# Display data on a remote Rerun server
display_ip: str | None = None
# For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to.
# Port of the remote Rerun server
display_port: int | None = None
# Whether to display compressed (JPEG) images instead of raw frames
# Whether to display compressed images in Rerun
display_compressed_images: bool = False
# Use vocal synthesis to read events.
play_sounds: bool = True
@@ -243,7 +233,6 @@ def record_loop(
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
display_mode: str = "rerun",
display_compressed_images: bool = False,
):
if dataset is not None and dataset.fps != fps:
@@ -338,11 +327,8 @@ def record_loop(
dataset.add_frame(frame)
if display_data:
log_visualization_data(
display_mode,
observation=obs_processed,
action=action_values,
compress_images=display_compressed_images,
log_rerun_data(
observation=obs_processed, action=action_values, compress_images=display_compressed_images
)
dt_s = time.perf_counter() - start_loop_t
@@ -368,9 +354,7 @@ def record(
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
init_visualization(
cfg.display_mode, session_name="recording", ip=cfg.display_ip, port=cfg.display_port
)
init_rerun(session_name="recording", ip=cfg.display_ip, port=cfg.display_port)
display_compressed_images = (
True
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
@@ -480,7 +464,6 @@ def record(
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
display_mode=cfg.display_mode,
display_compressed_images=display_compressed_images,
)
@@ -502,7 +485,6 @@ def record(
control_time_s=cfg.dataset.reset_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
display_mode=cfg.display_mode,
)
if events["rerecord_episode"]:
@@ -528,9 +510,6 @@ def record(
if listener is not None:
listener.stop()
if cfg.display_data:
shutdown_visualization(cfg.display_mode)
if cfg.dataset.push_to_hub:
if dataset and dataset.num_episodes > 0:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
+3 -13
View File
@@ -145,9 +145,6 @@ Usage examples
--dataset.rgb_encoder.vcodec=h264 \\
--dataset.rgb_encoder.preset=fast \\
--dataset.rgb_encoder.extra_options={"tune": "film", "profile:v": "high", "bf": 2}
# Stream to Foxglove instead of Rerun:
# add --display_mode=foxglove, then connect the Foxglove app to ws://127.0.0.1:8765.
"""
import logging
@@ -193,7 +190,7 @@ from lerobot.teleoperators import ( # noqa: F401
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
from lerobot.utils.visualization_utils import init_visualization, shutdown_visualization
from lerobot.utils.visualization_utils import init_rerun
logger = logging.getLogger(__name__)
@@ -204,13 +201,8 @@ def rollout(cfg: RolloutConfig):
init_logging()
if cfg.display_data:
logger.info(
"Initializing %s visualization (ip=%s, port=%s)",
cfg.display_mode,
cfg.display_ip,
cfg.display_port,
)
init_visualization(cfg.display_mode, session_name="rollout", ip=cfg.display_ip, port=cfg.display_port)
logger.info("Initializing Rerun visualization (ip=%s, port=%s)", cfg.display_ip, cfg.display_port)
init_rerun(session_name="rollout", ip=cfg.display_ip, port=cfg.display_port)
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
@@ -235,8 +227,6 @@ def rollout(cfg: RolloutConfig):
logger.info("Interrupted by user")
finally:
strategy.teardown(ctx)
if cfg.display_data:
shutdown_visualization(cfg.display_mode)
logger.info("Rollout finished")
+9 -39
View File
@@ -31,22 +31,6 @@ lerobot-teleoperate \
--display_data=true
```
To stream the data to Foxglove instead of Rerun, add ``--display_mode=foxglove``
(then connect the Foxglove app to ``ws://127.0.0.1:8765``; override the port with ``--display_port=<port>``):
```shell
lerobot-teleoperate \
--robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--robot.id=black \
--teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue \
--display_data=true \
--display_mode=foxglove
```
Example teleoperation with bimanual so100:
```shell
@@ -124,11 +108,7 @@ from lerobot.teleoperators import ( # noqa: F401
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, move_cursor_up
from lerobot.utils.visualization_utils import (
init_visualization,
log_visualization_data,
shutdown_visualization,
)
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
@dataclass
@@ -141,14 +121,11 @@ class TeleoperateConfig:
teleop_time_s: float | None = None
# Display all cameras on screen
display_data: bool = False
# Visualization backend used when display_data is True: "rerun" or "foxglove".
display_mode: str = "rerun"
# For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket
# server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces).
# Display data on a remote Rerun server
display_ip: str | None = None
# For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to.
# Port of the remote Rerun server
display_port: int | None = None
# Whether to display compressed (JPEG) images instead of raw frames
# Whether to display compressed images in Rerun
display_compressed_images: bool = False
@@ -160,7 +137,6 @@ def teleop_loop(
robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction],
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation],
display_data: bool = False,
display_mode: str = "rerun",
duration: float | None = None,
display_compressed_images: bool = False,
):
@@ -173,10 +149,8 @@ def teleop_loop(
teleop: The teleoperator device instance providing control actions.
robot: The robot instance being controlled.
fps: The target frequency for the control loop in frames per second.
display_data: If True, fetches robot observations and displays them in the console and the
visualization backend.
display_mode: Visualization backend to use when display_data is True ("rerun" or "foxglove").
display_compressed_images: If True, compresses images before sending them to the backend for display.
display_data: If True, fetches robot observations and displays them in the console and Rerun.
display_compressed_images: If True, compresses images before sending them to Rerun for display.
duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely.
teleop_action_processor: An optional pipeline to process raw actions from the teleoperator.
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
@@ -213,8 +187,7 @@ def teleop_loop(
# Process robot observation through pipeline
obs_transition = robot_observation_processor(obs)
log_visualization_data(
display_mode,
log_rerun_data(
observation=obs_transition,
action=teleop_action,
compress_images=display_compressed_images,
@@ -242,9 +215,7 @@ def teleoperate(cfg: TeleoperateConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
init_visualization(
cfg.display_mode, session_name="teleoperation", ip=cfg.display_ip, port=cfg.display_port
)
init_rerun(session_name="teleoperation", ip=cfg.display_ip, port=cfg.display_port)
display_compressed_images = (
True
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
@@ -264,7 +235,6 @@ def teleoperate(cfg: TeleoperateConfig):
robot=robot,
fps=cfg.fps,
display_data=cfg.display_data,
display_mode=cfg.display_mode,
duration=cfg.teleop_time_s,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
@@ -275,7 +245,7 @@ def teleoperate(cfg: TeleoperateConfig):
pass
finally:
if cfg.display_data:
shutdown_visualization(cfg.display_mode)
shutdown_rerun()
teleop.disconnect()
robot.disconnect()
+2 -2
View File
@@ -736,9 +736,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
unwrapped_model = accelerator.unwrap_model(policy)
# PEFT only applies when training a policy — reward models use the plain path.
if not cfg.is_reward_model_training and cfg.policy.use_peft:
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model, dataset_meta=dataset.meta)
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
else:
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict, dataset_meta=dataset.meta)
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
preprocessor.push_to_hub(active_cfg.repo_id)
postprocessor.push_to_hub(active_cfg.repo_id)
@@ -65,7 +65,7 @@ class RebotArm102LeaderConfig:
joint_ranges: dict[str, list[int]] = field(
default_factory=lambda: {
"shoulder_pan": [-150, 150],
"shoulder_lift": [-200, 1],
"shoulder_lift": [-170, 1],
"elbow_flex": [-200, 1],
"wrist_flex": [-80, 90],
"wrist_yaw": [-90, 90],
-1
View File
@@ -37,7 +37,6 @@ ACTION_TOKEN_MASK = ACTION + ".token_mask"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"
SUCCESS = "next.success"
INFO = "info"
ROBOTS = "robots"
-635
View File
@@ -1,635 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Foxglove visualization backend.
Live control-loop streaming (:func:`log_foxglove_data`) and seekable dataset playback
(:func:`serve_foxglove_dataset_playback`) over a Foxglove WebSocket server. Callers usually select a
backend at runtime through the dispatch in :mod:`lerobot.utils.visualization_utils` rather than
importing from here directly. Requires the ``viz`` extra (``pip install 'lerobot[viz]'``).
"""
import logging
import numbers
import time
import cv2
import numpy as np
from lerobot.types import RobotAction, RobotObservation
from .constants import (
ACTION,
ACTION_PREFIX,
DONE,
OBS_IMAGES,
OBS_PREFIX,
OBS_STATE,
OBS_STR,
REWARD,
SUCCESS,
TRUNCATED,
)
from .import_utils import require_package
# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}``
# pairs rather than one field per feature, so the same schema fits any robot regardless of which
# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name
# each series automatically, so a single filtered path plots every feature, e.g.
# ``/observation/state.scalars[:]``.
_SCALARS_SCHEMA = {
"type": "object",
"title": "lerobot.Scalars",
"properties": {
"scalars": {
"type": "array",
"items": {
"type": "object",
"properties": {
"label": {"type": "string"},
"value": {"type": "number"},
},
},
}
},
}
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
def init_foxglove(host: str = "127.0.0.1", port: int | None = 8765) -> None:
"""
Starts a Foxglove WebSocket server for visualizing the control loop.
Connect to it from the Foxglove app at ``ws://<host>:<port>``. Calling this
more than once is a no-op while a server is already running.
Args:
host: Host interface to bind the WebSocket server to.
port: Port to bind the WebSocket server to (defaults to 8765).
"""
require_package("foxglove-sdk", extra="viz", import_name="foxglove")
import foxglove
# Live-stream state lives as attributes on ``log_foxglove_data``:
# ``.server`` is the shared WebSocket server and
# ``.channels`` caches one Foxglove channel per topic
if getattr(log_foxglove_data, "server", None) is not None:
return
log_foxglove_data.server = foxglove.start_server(host=host, port=port or 8765)
log_foxglove_data.channels = {}
def shutdown_foxglove() -> None:
"""Stops the Foxglove WebSocket server and clears cached channels."""
server = getattr(log_foxglove_data, "server", None)
if server is not None:
server.stop()
log_foxglove_data.server = None
log_foxglove_data.channels = {}
def _foxglove_safe_name(name: str) -> str:
"""Replace ``.`` with ``_`` so a feature name is a single Foxglove topic-path segment.
Foxglove treats ``.`` as a path separator, so an unsanitized name like ``observation.images.front``
would split into nested segments instead of naming one topic.
"""
return name.replace(".", "_")
def _foxglove_topic(key: str, *, is_image: bool = False) -> str:
"""Build the Foxglove topic for a feature ``key``.
Camera features map to a per-source image topic (``/observation/images/<name>``); scalar features
share one aggregate topic per source: ``/observation/state`` for observations, ``/action/state``
for actions.
"""
if is_image:
name = str(key)
for prefix in (f"{OBS_IMAGES}.", OBS_PREFIX):
if name.startswith(prefix):
name = name[len(prefix) :]
break
return f"/{OBS_STR}/images/{_foxglove_safe_name(name)}"
source = ACTION if (str(key).startswith(ACTION_PREFIX) or str(key) == ACTION) else OBS_STR
return f"/{source}/state"
def _log_foxglove_scalars(
topic: str, values: dict[str, float], *, channels: dict | None = None, log_time: int | None = None
) -> None:
"""Log scalars on a typed JSON channel using the static :data:`_SCALARS_SCHEMA`.
``values`` is an ordered mapping of feature name to value; it is emitted as a ``scalars`` array of
``{label, value}`` objects. Insertion order is preserved so series stay stable across messages.
``channels`` is the per-topic channel cache to reuse (defaults to the live-stream cache on
:func:`log_foxglove_data`; dataset playback passes its own local cache to stay self-contained).
``log_time`` is the message time in nanoseconds; when ``None`` the server's receive time is used.
"""
if not values:
return
import foxglove
if channels is None:
channels = log_foxglove_data.channels
channel = channels.get(topic)
if channel is None:
channel = channels[topic] = foxglove.Channel(topic, schema=_SCALARS_SCHEMA, message_encoding="json")
msg = {"scalars": [{"label": label, "value": value} for label, value in values.items()]}
if log_time is None:
channel.log(msg)
else:
channel.log(msg, log_time=log_time)
def _labeled_scalars(name: str, values, labels: list[str] | None = None) -> dict[str, float]:
"""Expand a 1D sequence into ``{label: value}`` entries with a consistent fallback."""
flat = [float(v) for v in values]
if labels is None or len(labels) != len(flat):
labels = [f"{name}_{i}" for i in range(len(flat))]
return dict(zip(labels, flat, strict=True))
def _log_foxglove_image(
topic: str,
frame_id: str,
arr: np.ndarray,
*,
compress_images: bool,
channels: dict | None = None,
log_time: int | None = None,
depth_range: tuple[float, float] | None = None,
) -> None:
"""Log an image on a cached per-topic channel.
Frames are cast to ``uint8`` and the encoding is chosen from the channel count: 1 => ``mono8``,
3 => ``rgb8`` (float input assumed in [0, 1]), 4 => ``rgba8``; other counts are skipped with a
warning. When ``compress_images`` is set, ``mono8`` and ``rgb8`` are JPEG-encoded instead.
Args:
topic: Foxglove topic to log on.
frame_id: Frame id stamped on the message.
arr: Image as HWC or CHW (CHW is transposed to HWC), any dtype.
compress_images: JPEG-encode ``mono8`` and ``rgb8`` frames; ignored for ``rgba8``.
channels: Per-topic channel cache to reuse (see :func:`_log_foxglove_scalars`).
log_time: Message time in nanoseconds, also written to the header timestamp; when ``None``
the server's receive time is used.
depth_range: ``(lo, hi)`` bounds used to clip a single-channel frame before it is encoded as
a regular image.
"""
from foxglove.channels import CompressedImageChannel, RawImageChannel
from foxglove.messages import CompressedImage, RawImage, Timestamp
if channels is None:
channels = log_foxglove_data.channels
time_ns = time.time_ns() if log_time is None else log_time
timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000)
log_kwargs = {} if log_time is None else {"log_time": log_time}
# Convert CHW -> HWC when needed (mirrors log_rerun_data).
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
height, width = arr.shape[0], arr.shape[1]
n_channels = 1 if arr.ndim == 2 else arr.shape[2]
# Apply depth range clipping to single channel depth maps.
if depth_range is not None and n_channels == 1:
lo, hi = depth_range
arr = arr.clip(lo, hi)
if n_channels == 3 and np.issubdtype(arr.dtype, np.floating):
arr = (arr * 255.0).clip(0, 255)
arr = np.ascontiguousarray(arr, dtype=np.uint8)
if compress_images and n_channels in (1, 3):
buf_src = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) if n_channels == 3 else arr
_, buf = cv2.imencode(".jpg", buf_src)
channel = channels.get(topic)
if channel is None:
channel = channels[topic] = CompressedImageChannel(topic=topic)
channel.log(
CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg"),
**log_kwargs,
)
return
encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(n_channels)
if encoding is None:
logging.warning(
"Foxglove: skipping image on topic '%s' with unsupported shape %s (%d channels); "
"expected 1 (mono8), 3 (rgb8), or 4 (rgba8) channels.",
topic,
tuple(arr.shape),
n_channels,
)
return
channel = channels.get(topic)
if channel is None:
channel = channels[topic] = RawImageChannel(topic=topic)
channel.log(
RawImage(
timestamp=timestamp,
frame_id=frame_id,
width=width,
height=height,
encoding=encoding,
step=width * n_channels,
data=arr.tobytes(),
),
**log_kwargs,
)
def log_foxglove_data(
observation: RobotObservation | None = None,
action: RobotAction | None = None,
compress_images: bool = False,
) -> None:
"""
Logs observation and action data to a Foxglove WebSocket server for real-time visualization.
Mirrors ``log_rerun_data`` but emits Foxglove messages over the server started by
:func:`init_foxglove`. Data is mapped as follows:
- Scalars (and elements of 1D arrays) are accumulated per source and logged on the
``/observation/state`` and ``/action/state`` topics as typed JSON messages using the static
``lerobot.Scalars`` schema: a ``scalars`` array of ``{label, value}`` objects (see
:data:`_SCALARS_SCHEMA`). The ``label`` field lets Foxglove name each series automatically, so
``/observation/state.scalars[:].value`` plots every feature at once.
- 3D NumPy arrays that resemble images are transposed from CHW to HWC when needed and logged on a
per-source topic (e.g. ``/observation/images/front``) as a ``RawImage`` (or a JPEG
``CompressedImage`` when ``compress_images`` is True).
Args:
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
compress_images: Whether to JPEG-compress images before logging to save bandwidth in exchange
for CPU and quality.
"""
require_package("foxglove-sdk", extra="viz", import_name="foxglove")
if getattr(log_foxglove_data, "server", None) is None:
raise RuntimeError("init_foxglove() must be called before log_foxglove_data().")
now = time.time_ns()
if observation:
obs_scalars: dict[str, float] = {}
for k, v in observation.items():
if v is None:
continue
key = k[len(OBS_PREFIX) :] if str(k).startswith(OBS_PREFIX) else str(k)
if _is_scalar(v):
obs_scalars[key] = float(v)
elif isinstance(v, np.ndarray):
if v.ndim == 1:
obs_scalars.update(_labeled_scalars(key, v))
else:
_log_foxglove_image(
_foxglove_topic(k, is_image=True),
key,
v,
compress_images=compress_images,
log_time=now,
)
_log_foxglove_scalars(_foxglove_topic(OBS_STATE), obs_scalars, log_time=now)
if action:
action_scalars: dict[str, float] = {}
for k, v in action.items():
if v is None:
continue
key = k[len(ACTION_PREFIX) :] if str(k).startswith(ACTION_PREFIX) else str(k)
if _is_scalar(v):
action_scalars[key] = float(v)
elif isinstance(v, np.ndarray):
action_scalars.update(_labeled_scalars(key, v.flatten()))
_log_foxglove_scalars(_foxglove_topic(ACTION), action_scalars, log_time=now)
# ── Dataset playback over a Foxglove WebSocket server ─────────────────────
# A LeRobotDataset is random-access on disk, so rather than fire-and-forget a forward stream we
# advertise a seekable timeline and serve frames on demand for whatever time the user scrubs/plays
# to in the Foxglove app. This relies on the SDK's PlaybackControl capability.
def _feature_dim_names(feature: dict | None) -> list[str] | None:
"""Best-effort per-dimension series labels for a 1D feature, or ``None`` to fall back to indices.
LeRobot records a feature's ``names`` inconsistently: a flat list (``["x", "y"]``), a category
mapping (``{"motors": ["motor_0", "motor_1"]}``), or a name->index mapping
(``{"delta_x": 0, "delta_y": 1}``). Each is handled, but labels are only returned when their count
matches the feature's 1D shape, so a malformed/mismatched ``names`` can't silently mislabel series.
"""
if not feature:
return None
shape = feature.get("shape")
dim = shape[0] if shape and len(shape) == 1 else None
names = feature.get("names")
labels: list[str] | None = None
if isinstance(names, dict):
values = list(names.values())
if values and all(isinstance(v, (list, tuple)) for v in values):
labels = [str(n) for group in values for n in group]
elif values and all(isinstance(v, int) and not isinstance(v, bool) for v in values):
labels = [name for name, _ in sorted(names.items(), key=lambda kv: kv[1])]
elif isinstance(names, (list, tuple)):
labels = [str(n) for n in names]
if labels is not None and dim is not None and len(labels) == dim:
return labels
return None
def _frame_to_scalars(sample: dict, key: str, labels: list[str] | None = None) -> dict[str, float]:
"""Flatten a frame's vector/scalar feature ``key`` into ``{label: value}`` entries.
``labels`` provides one name per dimension (from the dataset's feature metadata); when absent or
the wrong length, dimensions fall back to ``{name}_{i}`` (the short feature name), matching the
live stream so series names agree. A scalar feature becomes a single entry. Missing or ``None``
features yield an empty mapping.
"""
v = sample.get(key)
if v is None:
return {}
arr = v.numpy() if hasattr(v, "numpy") else np.asarray(v)
if key.startswith(OBS_PREFIX):
name = key[len(OBS_PREFIX) :]
elif key.startswith(ACTION_PREFIX):
name = key[len(ACTION_PREFIX) :]
else:
name = key
if arr.ndim == 0:
return {name: float(arr)}
return _labeled_scalars(name, arr.flatten(), labels)
def serve_foxglove_dataset_playback(
dataset,
episode_index: int,
*,
host: str = "127.0.0.1",
port: int = 8765,
compress_images: bool = False,
autoplay: bool = True,
) -> None:
"""Serve a single dataset episode to Foxglove as a seekable, scrubbable timeline.
Starts a Foxglove WebSocket server advertising the ``PlaybackControl`` capability over the
episode's time range. The Foxglove app drives play/pause/seek/speed; a background thread and a
``ServerListener`` read frames from the on-disk ``dataset`` on demand and log them stamped at
their dataset timestamps, so the user can scrub anywhere in the episode. Blocks until interrupted.
Args:
dataset: A ``LeRobotDataset`` loaded for the single episode to visualize.
episode_index: Index of the episode being visualized (used only for the session name).
host: Host interface to bind the WebSocket server to.
port: Port to bind the WebSocket server to.
compress_images: Whether to JPEG-compress camera frames before logging.
autoplay: If True, start playing automatically as soon as a client connects, instead of
waiting for the user to press play in the Foxglove app.
"""
require_package("foxglove-sdk", extra="viz", import_name="foxglove")
import bisect
import threading
import foxglove
from foxglove.websocket import (
Capability,
PlaybackCommand,
PlaybackControlRequest,
PlaybackState,
PlaybackStatus,
ServerListener,
)
# Per-frame timestamps in nanoseconds (read straight from the table, no video decode).
times_ns = [int(round(float(t) * 1e9)) for t in dataset.hf_dataset["timestamp"]]
n_frames = len(times_ns)
if n_frames == 0:
raise ValueError("Cannot visualize an empty episode.")
first_ns, last_ns = times_ns[0], times_ns[-1]
camera_keys = list(dataset.meta.camera_keys)
# Dataset-wide q01/q99 depth bounds (fallback min/max) used to normalize depth to [0, 1].
depth_ranges: dict[str, tuple[float, float]] = {}
for key in dataset.meta.depth_keys:
stats = (dataset.meta.stats or {}).get(key)
if not stats:
continue
lo = stats["q01"] if "q01" in stats else stats["min"]
hi = stats["q99"] if "q99" in stats else stats["max"]
depth_ranges[key] = (float(np.asarray(lo).item()), float(np.asarray(hi).item()))
# Per-dimension series labels from the dataset metadata (e.g. joint names), computed once.
scalar_labels = {
OBS_STATE: _feature_dim_names(dataset.meta.features.get(OBS_STATE)),
ACTION: _feature_dim_names(dataset.meta.features.get(ACTION)),
}
# Local channel cache so the playback server is self-contained and doesn't touch the live-stream cache.
channels: dict = {}
def emit_frame(i: int) -> None:
"""Log every channel for frame ``i`` stamped at its dataset timestamp."""
sample = dataset[i]
log_time = times_ns[i]
for key in camera_keys:
arr = sample.get(key)
if arr is None:
continue
arr = arr.numpy() if hasattr(arr, "numpy") else np.asarray(arr)
_log_foxglove_image(
_foxglove_topic(key, is_image=True),
key,
arr,
compress_images=compress_images,
channels=channels,
log_time=log_time,
depth_range=depth_ranges.get(key),
)
_log_foxglove_scalars(
_foxglove_topic(OBS_STATE),
_frame_to_scalars(sample, OBS_STATE, scalar_labels[OBS_STATE]),
channels=channels,
log_time=log_time,
)
_log_foxglove_scalars(
_foxglove_topic(ACTION),
_frame_to_scalars(sample, ACTION, scalar_labels[ACTION]),
channels=channels,
log_time=log_time,
)
episode_scalars = {}
for feat, label in (
(DONE, "done"),
(TRUNCATED, "truncated"),
(REWARD, "reward"),
(SUCCESS, "success"),
):
v = sample.get(feat)
if v is not None:
episode_scalars[label] = float(v)
_log_foxglove_scalars("/episode/state", episode_scalars, channels=channels, log_time=log_time)
lock = threading.Lock()
stop_event = threading.Event()
# Shared playback state, guarded by ``lock``. ``seek_idx`` is a one-shot request set by the
# listener and serviced by the playback loop, which is the *only* thread that emits frames (so
# concurrent random access into the on-disk dataset / video decoder never overlaps).
state = {
"status": PlaybackStatus.Paused,
"cursor": first_ns,
"speed": 1.0,
"last_idx": -1,
"seek_idx": None,
}
def index_at(t_ns: int) -> int:
return max(0, min(n_frames - 1, bisect.bisect_right(times_ns, t_ns) - 1))
# One-shot latch so autoplay fires only on the first client subscription.
autoplay_started = threading.Event()
class _PlaybackListener(ServerListener):
def on_subscribe(self, client, channel):
# Start playing automatically once a client actually connects (subscribes). Using the
# subscribe hook, rather than starting in Playing up front, means the timeline doesn't
# advance before anyone is watching. Fires once; the user can still pause/seek after.
if not autoplay:
return
with lock:
if autoplay_started.is_set() or state["status"] != PlaybackStatus.Paused:
return
autoplay_started.set()
state["status"] = PlaybackStatus.Playing
cursor, speed = state["cursor"], state["speed"]
server.broadcast_playback_state(PlaybackState(PlaybackStatus.Playing, cursor, speed, False, ""))
def on_playback_control_request(self, req: PlaybackControlRequest):
# Only mutate state here; the playback loop performs all frame emission.
with lock:
did_seek = False
if req.seek_time is not None:
cursor = max(first_ns, min(last_ns, req.seek_time))
state["cursor"] = cursor
state["last_idx"] = state["seek_idx"] = index_at(cursor)
did_seek = True
if req.playback_speed and req.playback_speed > 0:
state["speed"] = req.playback_speed
if req.playback_command == PlaybackCommand.Play:
# Restarting from the end replays from the beginning.
if state["cursor"] >= last_ns:
state["cursor"] = first_ns
state["last_idx"] = state["seek_idx"] = 0
did_seek = True
state["status"] = PlaybackStatus.Playing
elif req.playback_command == PlaybackCommand.Pause:
state["status"] = PlaybackStatus.Paused
status, cursor, speed = state["status"], state["cursor"], state["speed"]
request_id = req.request_id or ""
return PlaybackState(status, cursor, speed, did_seek, request_id)
server = foxglove.start_server(
name=f"{dataset.repo_id}/episode_{episode_index}",
host=host,
port=port,
capabilities=[Capability.PlaybackControl, Capability.Time],
server_listener=_PlaybackListener(),
playback_time_range=(first_ns, last_ns),
)
def playback_loop() -> None:
# Cap how far the cursor may advance in a single tick. A slow frame decode (or any stall)
# would otherwise make ``dt`` huge and produce one enormous catch-up batch; clamping it makes
# playback trail wall-clock under a slow decoder while each tick emits a bounded frame range.
max_tick_dt_s = 0.25
prev = time.monotonic()
while not stop_event.is_set():
time.sleep(1.0 / 60.0)
ended = False
speed = 1.0
with lock:
now = time.monotonic()
dt = min(now - prev, max_tick_dt_s)
prev = now
# A queued seek is always serviced, even while paused, so scrubbing updates the view.
work = []
seek_idx = state["seek_idx"]
if seek_idx is not None:
state["seek_idx"] = None
work.append(seek_idx)
if state["status"] == PlaybackStatus.Playing:
cursor = state["cursor"] + int(dt * 1e9 * state["speed"])
start_idx = state["last_idx"] + 1
if cursor >= last_ns:
cursor, target, ended = last_ns, n_frames - 1, True
else:
target = index_at(cursor)
state["cursor"] = cursor
work.extend(range(start_idx, target + 1))
# cursor only grows while playing (seeks reset last_idx in the listener), so
# target >= last_idx here; a plain assignment is correct and clearer than max().
state["last_idx"] = target
if ended:
state["status"] = PlaybackStatus.Ended
if not work:
continue
cursor, speed = state["cursor"], state["speed"]
# Emit outside the lock; this is the only thread that calls emit_frame. Re-check
# stop_event between frames so shutdown stays responsive even mid-batch.
for i in work:
if stop_event.is_set():
break
emit_frame(i)
server.broadcast_time(cursor)
if ended:
server.broadcast_playback_state(PlaybackState(PlaybackStatus.Ended, cursor, speed, False, ""))
# Emit the first frame so channels are advertised (done before the loop starts, so emission stays
# single-threaded). Late-connecting clients re-receive frames once they seek/play.
emit_frame(0)
with lock:
state["last_idx"] = 0
server.broadcast_time(first_ns)
server.broadcast_playback_state(PlaybackState(PlaybackStatus.Paused, first_ns, 1.0, True, ""))
thread = threading.Thread(target=playback_loop, name="foxglove-playback", daemon=True)
thread.start()
print(f"Foxglove server running. Connect the Foxglove app to ws://{host}:{port}")
print("Use the playback controls in Foxglove to play/pause and scrub the episode. Ctrl-C to exit.")
try:
while not stop_event.is_set():
time.sleep(0.5)
except KeyboardInterrupt:
print("Ctrl-C received. Exiting.")
finally:
stop_event.set()
thread.join(timeout=2.0)
server.stop()
channels.clear()
-184
View File
@@ -1,184 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rerun visualization backend.
Live control-loop streaming to the Rerun viewer (:func:`log_rerun_data`). Callers usually select a
backend at runtime through the dispatch in :mod:`lerobot.utils.visualization_utils` rather than
importing from here directly. Requires the ``viz`` extra (``pip install 'lerobot[viz]'``).
"""
import numbers
import os
import numpy as np
from lerobot.types import RobotAction, RobotObservation
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
from .import_utils import require_package
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
def init_rerun(
session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
) -> None:
"""
Initializes the Rerun SDK for visualizing the control loop.
Args:
session_name: Name of the Rerun session.
ip: Optional IP for connecting to a Rerun server.
port: Optional port for connecting to a Rerun server.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
log_rerun_data.blueprint = None # Reset blueprint cache for new session
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
rr.init(session_name)
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
if ip and port:
rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
else:
rr.spawn(memory_limit=memory_limit)
def shutdown_rerun() -> None:
"""Shuts down the Rerun SDK gracefully."""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
rr.rerun_shutdown()
def _build_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]):
"""Build a Rerun blueprint laying out camera images, observation and action scalars in separate views.
Camera images, observation and action scalars are arranged in a grid.
"""
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun.blueprint as rrb
views = [rrb.Spatial2DView(origin=path, name=path) for path in sorted(image_paths)]
if observation_paths:
views.append(rrb.TimeSeriesView(name="observation", contents=sorted(observation_paths)))
if action_paths:
views.append(rrb.TimeSeriesView(name="action", contents=sorted(action_paths)))
return rrb.Blueprint(rrb.Grid(*views))
def _ensure_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]) -> None:
"""Build and send the blueprint once, from the first observation and action data."""
if getattr(log_rerun_data, "blueprint", None) is not None:
return
if not (observation_paths or action_paths or image_paths):
return
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun as rr
blueprint = _build_blueprint(observation_paths, action_paths, image_paths)
log_rerun_data.blueprint = blueprint
rr.send_blueprint(blueprint)
def log_rerun_data(
observation: RobotObservation | None = None,
action: RobotAction | None = None,
compress_images: bool = False,
) -> None:
"""
Logs observation and action data to Rerun for real-time visualization.
This function iterates through the provided observation and action dictionaries and sends their contents
to the Rerun viewer. It handles different data types appropriately:
- Scalars values (floats, ints) are logged as `rr.Scalars`.
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
- 1D NumPy arrays are logged as a single `rr.Scalars` batch under one entity path, so that every
dimension shares the same view instead of being split across one view per element.
- Multi-dimensional **action** arrays are flattened and logged as a single `rr.Scalars` batch.
Keys are automatically namespaced with "observation." or "action." if not already present.
On the first call, a blueprint is built and sent so observation and action scalars get separate
time-series views and each image gets its own spatial view.
Args:
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
observation_paths: set[str] = set()
action_paths: set[str] = set()
image_paths: set[str] = set()
if observation:
for k, v in observation.items():
if v is None:
continue
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
observation_paths.add(key)
elif isinstance(v, np.ndarray):
arr = v
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 1:
rr.log(key, rr.Scalars(arr.astype(float)))
observation_paths.add(key)
else:
if arr.shape[-1] == 1:
img_entity = rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity, static=True)
image_paths.add(key)
if action:
for k, v in action.items():
if v is None:
continue
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
action_paths.add(key)
elif isinstance(v, np.ndarray):
# Flatten any (incl. higher-dimensional) array into a single batched Scalars
rr.log(key, rr.Scalars(v.reshape(-1).astype(float)))
action_paths.add(key)
_ensure_blueprint(observation_paths, action_paths, image_paths)
+142 -44
View File
@@ -12,68 +12,166 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Backend-agnostic visualization dispatch.
import numbers
import os
Selects a visualization backend at runtime via a display-mode string (e.g. a ``--display_mode`` CLI
flag) so callers never branch on the backend. The concrete implementations live in
:mod:`lerobot.utils.rerun_visualization` and :mod:`lerobot.utils.foxglove_visualization`; importing
this module does not import ``rerun`` or ``foxglove`` (each backend imports its SDK lazily behind a
``require_package`` guard).
"""
import numpy as np
from lerobot.types import RobotAction, RobotObservation
from .foxglove_visualization import init_foxglove, log_foxglove_data, shutdown_foxglove
from .rerun_visualization import init_rerun, log_rerun_data, shutdown_rerun
# Visualization backends selectable at runtime via a display-mode string (e.g. a --display_mode flag).
VISUALIZATION_MODES = ("rerun", "foxglove")
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
from .import_utils import require_package
def init_visualization(
display_mode: str,
*,
session_name: str = "lerobot_control_loop",
ip: str | None = None,
port: int | None = None,
def init_rerun(
session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
) -> None:
"""Initializes the visualization backend selected by ``display_mode``.
"""
Initializes the Rerun SDK for visualizing the control loop.
For ``"rerun"``, ``ip``/``port`` point at an optional remote Rerun server. For ``"foxglove"``,
``ip`` is the interface to bind the WebSocket server to (``127.0.0.1`` for local only, ``0.0.0.0``
for all interfaces) and ``port`` is its port.
Args:
session_name: Name of the Rerun session.
ip: Optional IP for connecting to a Rerun server.
port: Optional port for connecting to a Rerun server.
"""
if display_mode == "rerun":
init_rerun(session_name=session_name, ip=ip, port=port)
elif display_mode == "foxglove":
init_foxglove(host=ip or "127.0.0.1", port=port)
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
log_rerun_data.blueprint = None # Reset blueprint cache for new session
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
rr.init(session_name)
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
if ip and port:
rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
else:
raise ValueError(f"Unknown display_mode '{display_mode}'. Expected one of {VISUALIZATION_MODES}.")
rr.spawn(memory_limit=memory_limit)
def log_visualization_data(
display_mode: str,
def shutdown_rerun() -> None:
"""Shuts down the Rerun SDK gracefully."""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
rr.rerun_shutdown()
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
def _build_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]):
"""Build a Rerun blueprint laying out camera images, observation and action scalars in separate views.
Camera images, observation and action scalars are arranged in a grid.
"""
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun.blueprint as rrb
views = [rrb.Spatial2DView(origin=path, name=path) for path in sorted(image_paths)]
if observation_paths:
views.append(rrb.TimeSeriesView(name="observation", contents=sorted(observation_paths)))
if action_paths:
views.append(rrb.TimeSeriesView(name="action", contents=sorted(action_paths)))
return rrb.Blueprint(rrb.Grid(*views))
def _ensure_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]) -> None:
"""Build and send the blueprint once, from the first observation and action data."""
if getattr(log_rerun_data, "blueprint", None) is not None:
return
if not (observation_paths or action_paths or image_paths):
return
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun as rr
blueprint = _build_blueprint(observation_paths, action_paths, image_paths)
log_rerun_data.blueprint = blueprint
rr.send_blueprint(blueprint)
def log_rerun_data(
observation: RobotObservation | None = None,
action: RobotAction | None = None,
compress_images: bool = False,
) -> None:
"""Logs observation/action data to the backend selected by ``display_mode``."""
"""
Logs observation and action data to Rerun for real-time visualization.
if display_mode == "rerun":
log_rerun_data(observation=observation, action=action, compress_images=compress_images)
elif display_mode == "foxglove":
log_foxglove_data(observation=observation, action=action, compress_images=compress_images)
else:
raise ValueError(f"Unknown display_mode '{display_mode}'. Expected one of {VISUALIZATION_MODES}.")
This function iterates through the provided observation and action dictionaries and sends their contents
to the Rerun viewer. It handles different data types appropriately:
- Scalars values (floats, ints) are logged as `rr.Scalars`.
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
- 1D NumPy arrays are logged as a single `rr.Scalars` batch under one entity path, so that every
dimension shares the same view instead of being split across one view per element.
- Multi-dimensional **action** arrays are flattened and logged as a single `rr.Scalars` batch.
Keys are automatically namespaced with "observation." or "action." if not already present.
def shutdown_visualization(display_mode: str) -> None:
"""Shuts down the backend selected by ``display_mode``."""
On the first call, a blueprint is built and sent so observation and action scalars get separate
time-series views and each image gets its own spatial view.
if display_mode == "rerun":
shutdown_rerun()
elif display_mode == "foxglove":
shutdown_foxglove()
else:
raise ValueError(f"Unknown display_mode '{display_mode}'. Expected one of {VISUALIZATION_MODES}.")
Args:
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
observation_paths: set[str] = set()
action_paths: set[str] = set()
image_paths: set[str] = set()
if observation:
for k, v in observation.items():
if v is None:
continue
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
observation_paths.add(key)
elif isinstance(v, np.ndarray):
arr = v
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 1:
rr.log(key, rr.Scalars(arr.astype(float)))
observation_paths.add(key)
else:
if arr.shape[-1] == 1:
img_entity = rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity, static=True)
image_paths.add(key)
if action:
for k, v in action.items():
if v is None:
continue
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
action_paths.add(key)
elif isinstance(v, np.ndarray):
# Flatten any (incl. higher-dimensional) array into a single batched Scalars
rr.log(key, rr.Scalars(v.reshape(-1).astype(float)))
action_paths.add(key)
_ensure_blueprint(observation_paths, action_paths, image_paths)
+2 -1
View File
@@ -1531,6 +1531,7 @@ def test_valid_video_codecs_constant():
assert "h264" in VALID_VIDEO_CODECS
assert "hevc" in VALID_VIDEO_CODECS
assert "libsvtav1" in VALID_VIDEO_CODECS
assert "libaom-av1" in VALID_VIDEO_CODECS
assert "auto" in VALID_VIDEO_CODECS
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
assert "h264_nvenc" in VALID_VIDEO_CODECS
@@ -1538,7 +1539,7 @@ def test_valid_video_codecs_constant():
assert "h264_qsv" in VALID_VIDEO_CODECS
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
assert "hevc_nvenc" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 10
assert len(VALID_VIDEO_CODECS) == 11
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
+3 -20
View File
@@ -91,11 +91,10 @@ def test_get_observation_converts_to_degrees(follower):
def test_send_action_clips_to_joint_limits(follower):
# shoulder_pan limit is (-150, 150); request beyond the upper bound.
# shoulder_pan limit is (-145, 145); request beyond the upper bound.
returned = follower.send_action({"shoulder_pan.pos": 999.0})
assert returned["shoulder_pan.pos"] == 150.0
# Default control_mode is "mit", so arm joints are driven via send_mit.
follower.motors["shoulder_pan"].send_mit.assert_called_once()
assert returned["shoulder_pan.pos"] == 145.0
follower.motors["shoulder_pan"].send_pos_vel.assert_called_once()
def test_send_action_routes_gripper_to_force_pos(follower):
@@ -104,22 +103,6 @@ def test_send_action_routes_gripper_to_force_pos(follower):
follower.motors["gripper"].send_pos_vel.assert_not_called()
def test_gripper_mit_mode_routes_to_send_mit():
bus_mock = _make_bus_mock()
with (
patch(f"{_MODULE}.require_package", lambda *a, **kw: None),
patch(f"{_MODULE}.MotorBridgeController") as controller_cls,
patch(f"{_MODULE}.MotorBridgeMode", MagicMock()),
):
controller_cls.from_dm_serial.return_value = bus_mock
cfg = RebotB601FollowerRobotConfig(port="/dev/null", gripper_control_mode="mit")
robot = RebotB601Follower(cfg)
robot.connect(calibrate=False)
robot.send_action({"gripper.pos": -10.0})
robot.motors["gripper"].send_mit.assert_called_once()
robot.motors["gripper"].send_force_pos.assert_not_called()
def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebotB601FollowerConfig(
-101
View File
@@ -1,101 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Foxglove backend's pure helpers.
These cover topic naming, series labelling and feature-name parsing. They import
``foxglove_visualization`` directly and need NO ``foxglove`` extra: the SDK is imported lazily inside
the functions that talk to the server, so the helpers below run in the base test tier.
"""
import numpy as np
from lerobot.utils import foxglove_visualization as fv
from lerobot.utils.constants import ACTION, OBS_STATE
def test_foxglove_safe_name_collapses_dots():
assert fv._foxglove_safe_name("observation.images.front") == "observation_images_front"
assert fv._foxglove_safe_name("plain") == "plain"
def test_foxglove_topic_image_strips_prefix_without_doubling_images():
# Fully-qualified camera key -> single clean segment (no doubled "images").
assert fv._foxglove_topic("observation.images.front", is_image=True) == "/observation/images/front"
# A nested camera name keeps its structure via safe-name collapsing.
assert (
fv._foxglove_topic("observation.images.wrist.left", is_image=True) == "/observation/images/wrist_left"
)
# Bare camera name (as real robots emit).
assert fv._foxglove_topic("front", is_image=True) == "/observation/images/front"
def test_foxglove_topic_scalar_sources():
assert fv._foxglove_topic(OBS_STATE) == "/observation/state"
assert fv._foxglove_topic("observation.environment_state") == "/observation/state"
assert fv._foxglove_topic(ACTION) == "/action/state"
assert fv._foxglove_topic("action.delta") == "/action/state"
def test_labeled_scalars_uses_labels_then_index_fallback():
assert fv._labeled_scalars("state", np.array([1.0, 2.0, 3.0])) == {
"state_0": 1.0,
"state_1": 2.0,
"state_2": 3.0,
}
assert fv._labeled_scalars("state", [1.0, 2.0], ["pan", "lift"]) == {"pan": 1.0, "lift": 2.0}
# Wrong-length labels fall back to index naming (never silently mislabels).
assert fv._labeled_scalars("q", [1.0, 2.0], ["only_one"]) == {"q_0": 1.0, "q_1": 2.0}
def test_frame_to_scalars_matches_live_labeling_and_handles_scalar():
frame = {OBS_STATE: np.array([1.0, 2.0])}
# No metadata -> {short_name}_{i}, identical to the live-stream fallback.
assert fv._frame_to_scalars(frame, OBS_STATE) == fv._labeled_scalars("state", np.array([1.0, 2.0]))
assert fv._frame_to_scalars(frame, OBS_STATE) == {"state_0": 1.0, "state_1": 2.0}
# Metadata labels are honored.
assert fv._frame_to_scalars(frame, OBS_STATE, ["pan", "lift"]) == {"pan": 1.0, "lift": 2.0}
# A 0-d scalar becomes a single entry named by the short feature name.
assert fv._frame_to_scalars({ACTION: np.array(5.0)}, ACTION) == {"action": 5.0}
# A missing feature yields an empty mapping.
assert fv._frame_to_scalars({}, OBS_STATE) == {}
def test_feature_dim_names_formats():
# Flat list of names.
assert fv._feature_dim_names({"shape": [2], "names": ["x", "y"]}) == ["x", "y"]
# Category mapping (dict of lists).
assert fv._feature_dim_names({"shape": [2], "names": {"motors": ["m0", "m1"]}}) == ["m0", "m1"]
# name -> index mapping (returned sorted by index).
assert fv._feature_dim_names({"shape": [2], "names": {"delta_x": 0, "delta_y": 1}}) == [
"delta_x",
"delta_y",
]
# Bool values must NOT be treated as an index map (bool is a subclass of int).
assert fv._feature_dim_names({"shape": [2], "names": {"a": True, "b": False}}) is None
# Mismatched length -> None (won't silently mislabel).
assert fv._feature_dim_names({"shape": [3], "names": ["x", "y"]}) is None
# Missing / absent names -> None.
assert fv._feature_dim_names(None) is None
assert fv._feature_dim_names({"shape": [2]}) is None
def test_is_scalar():
assert fv._is_scalar(1.0)
assert fv._is_scalar(np.float32(2.0))
assert fv._is_scalar(np.array(3.0)) # 0-d array
assert not fv._is_scalar(np.array([1.0, 2.0]))
assert not fv._is_scalar("x")
-310
View File
@@ -1,310 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import sys
from types import SimpleNamespace
import numpy as np
import pytest
pytest.importorskip("rerun", reason="rerun-sdk is required (install lerobot[viz])")
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_STATE
@pytest.fixture
def mock_rerun(monkeypatch):
"""
Provide a mock `rerun` module (and `rerun.blueprint` submodule) so tests don't
depend on the real library. Also reload the module-under-test so it binds to
this mock `rr`.
"""
calls = []
blueprints = []
class DummyScalar:
def __init__(self, value):
# Scalars may be built from a single float or from a 1D array batch.
self.value = value
class DummyImage:
def __init__(self, arr):
self.arr = arr
def compress(self, *a, **k):
return self
class DummyDepthImage:
def __init__(self, arr, colormap=None):
self.arr = arr
self.colormap = colormap
def dummy_log(key, obj=None, **kwargs):
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
if obj is None and "entity" in kwargs:
obj = kwargs.pop("entity")
calls.append((key, obj, kwargs))
def dummy_send_blueprint(blueprint, *a, **k):
blueprints.append(blueprint)
# Mock the `rerun.blueprint` submodule used to build the layout.
dummy_rrb = SimpleNamespace(
Spatial2DView=lambda origin=None, name=None: SimpleNamespace(
kind="Spatial2DView", origin=origin, name=name
),
TimeSeriesView=lambda name=None, contents=None: SimpleNamespace(
kind="TimeSeriesView", name=name, contents=contents
),
Grid=lambda *views: SimpleNamespace(kind="Grid", views=list(views)),
Blueprint=lambda root: SimpleNamespace(kind="Blueprint", root=root),
)
dummy_rr = SimpleNamespace(
__name__="rerun",
__package__="rerun",
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
Scalars=DummyScalar,
Image=DummyImage,
DepthImage=DummyDepthImage,
components=SimpleNamespace(Colormap=SimpleNamespace(Viridis="viridis")),
log=dummy_log,
send_blueprint=dummy_send_blueprint,
init=lambda *a, **k: None,
spawn=lambda *a, **k: None,
blueprint=dummy_rrb,
)
# Inject fake modules into sys.modules (both `rerun` and `rerun.blueprint`).
monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
monkeypatch.setitem(sys.modules, "rerun.blueprint", dummy_rrb)
# Now import and reload the module under test, to bind to our rerun mock
import lerobot.utils.rerun_visualization as rv
importlib.reload(rv)
# Expose the reloaded module, the call recorder and the captured blueprints
yield rv, calls, blueprints
def _keys(calls):
"""Helper to extract just the keys logged to rr.log"""
return [k for (k, _obj, _kw) in calls]
def _obj_for(calls, key):
"""Find the first object logged under a given key."""
for k, obj, _kw in calls:
if k == key:
return obj
raise KeyError(f"Key {key} not found in calls: {calls}")
def _kwargs_for(calls, key):
for k, _obj, kw in calls:
if k == key:
return kw
raise KeyError(f"Key {key} not found in calls: {calls}")
def _views_by_kind(blueprint, kind):
"""Return the views of a given kind from the (single) blueprint's grid."""
return [v for v in blueprint.root.views if v.kind == kind]
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
rv, calls, blueprints = mock_rerun
# Build EnvTransition dict
obs = {
f"{OBS_STATE}.temperature": np.float32(25.0),
# CHW image should be converted to HWC for rr.Image
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
}
act = {
"action.throttle": 0.7,
# 1D array should be logged as a single Scalars batch under one entity path
"action.vector": np.array([1.0, 2.0], dtype=np.float32),
}
transition = {
TransitionKey.OBSERVATION: obs,
TransitionKey.ACTION: act,
}
# Extract observation and action data from transition like in the real call sites
obs_data = transition.get(TransitionKey.OBSERVATION, {})
action_data = transition.get(TransitionKey.ACTION, {})
rv.log_rerun_data(observation=obs_data, action=action_data)
# We expect:
# - observation.state.temperature -> Scalars
# - observation.camera -> Image (HWC) with static=True
# - action.throttle -> Scalars
# - action.vector -> single Scalars batch (no per-element suffix)
expected_keys = {
f"{OBS_STATE}.temperature",
"observation.camera",
"action.throttle",
"action.vector",
}
assert set(_keys(calls)) == expected_keys
# Check scalar types and values
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
assert type(temp_obj).__name__ == "DummyScalar"
assert float(temp_obj.value) == pytest.approx(25.0)
throttle_obj = _obj_for(calls, "action.throttle")
assert type(throttle_obj).__name__ == "DummyScalar"
assert float(throttle_obj.value) == pytest.approx(0.7)
# 1D vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vector")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [1.0, 2.0])
# Check image handling: CHW -> HWC
img_obj = _obj_for(calls, "observation.camera")
assert type(img_obj).__name__ == "DummyImage"
assert img_obj.arr.shape == (10, 20, 3) # transposed
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
# A blueprint should have been built and sent exactly once, and cached on the function.
assert len(blueprints) == 1
assert rv.log_rerun_data.blueprint is blueprints[0]
bp = blueprints[0]
# One spatial view per image path
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.camera"}
# One time-series view each for observation and action scalars
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert set(ts_views) == {"observation", "action"}
assert ts_views["observation"].contents == [f"{OBS_STATE}.temperature"]
assert ts_views["action"].contents == ["action.throttle", "action.vector"]
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
rv, calls, blueprints = mock_rerun
# First dict without prefixes treated as observation
# Second dict without prefixes treated as action
obs_plain = {
"temp": 1.5,
# Already HWC image => should stay as-is
"img": np.zeros((5, 6, 3), dtype=np.uint8),
"none": None, # should be skipped
}
act_plain = {
"throttle": 0.3,
"vec": np.array([9, 8, 7], dtype=np.float32),
}
# Extract observation and action data from list like the old function logic did
# First dict was treated as observation, second as action
rv.log_rerun_data(observation=obs_plain, action=act_plain)
# Expected keys with auto-prefixes. The 1D vector is a single batched Scalars.
expected = {
"observation.temp",
"observation.img",
"action.throttle",
"action.vec",
}
logged = set(_keys(calls))
assert logged == expected
# Scalars
t = _obj_for(calls, "observation.temp")
assert type(t).__name__ == "DummyScalar"
assert float(t.value) == pytest.approx(1.5)
throttle = _obj_for(calls, "action.throttle")
assert type(throttle).__name__ == "DummyScalar"
assert float(throttle.value) == pytest.approx(0.3)
# Image stays HWC
img = _obj_for(calls, "observation.img")
assert type(img).__name__ == "DummyImage"
assert img.arr.shape == (5, 6, 3)
assert _kwargs_for(calls, "observation.img").get("static", False) is True
# Vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vec")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [9, 8, 7])
# Blueprint sent once with the expected view layout
assert len(blueprints) == 1
bp = blueprints[0]
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.img"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.throttle", "action.vec"]
def test_log_rerun_data_kwargs_only(mock_rerun):
rv, calls, blueprints = mock_rerun
rv.log_rerun_data(
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
action={"action.a": 1.0},
)
keys = set(_keys(calls))
assert "observation.temp" in keys
assert "observation.gray" in keys
assert "action.a" in keys
temp = _obj_for(calls, "observation.temp")
assert type(temp).__name__ == "DummyScalar"
assert float(temp.value) == pytest.approx(10.0)
img = _obj_for(calls, "observation.gray")
assert type(img).__name__ == "DummyDepthImage" # single-channel -> DepthImage
assert img.arr.shape == (8, 8, 1) # remains HWC
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
a = _obj_for(calls, "action.a")
assert type(a).__name__ == "DummyScalar"
assert float(a.value) == pytest.approx(1.0)
# Blueprint sent once, with a spatial view for the image and time-series views for scalars
assert len(blueprints) == 1
bp = blueprints[0]
assert {v.origin for v in _views_by_kind(bp, "Spatial2DView")} == {"observation.gray"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.a"]
def test_log_rerun_data_blueprint_sent_only_once(mock_rerun):
"""The blueprint is built from the first call and not resent on subsequent calls."""
rv, calls, blueprints = mock_rerun
rv.log_rerun_data(observation={"temp": 1.0}, action={"a": 2.0})
assert len(blueprints) == 1
first_blueprint = rv.log_rerun_data.blueprint
rv.log_rerun_data(observation={"temp": 3.0}, action={"a": 4.0})
# Still only one blueprint, and the cached one is unchanged.
assert len(blueprints) == 1
assert rv.log_rerun_data.blueprint is first_blueprint
+287 -13
View File
@@ -14,23 +14,297 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the backend-agnostic visualization dispatch.
These exercise the display-mode routing/validation only; they need neither ``rerun`` nor
``foxglove`` installed since the unknown-mode branch raises before touching any backend. Backend
behavior is covered in ``test_rerun_visualization.py`` and ``test_foxglove_visualization.py``.
"""
import importlib
import sys
from types import SimpleNamespace
import numpy as np
import pytest
from lerobot.utils import visualization_utils as vu
pytest.importorskip("rerun", reason="rerun-sdk is required (install lerobot[viz])")
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_STATE
def test_visualization_modes():
assert vu.VISUALIZATION_MODES == ("rerun", "foxglove")
@pytest.fixture
def mock_rerun(monkeypatch):
"""
Provide a mock `rerun` module (and `rerun.blueprint` submodule) so tests don't
depend on the real library. Also reload the module-under-test so it binds to
this mock `rr`.
"""
calls = []
blueprints = []
class DummyScalar:
def __init__(self, value):
# Scalars may be built from a single float or from a 1D array batch.
self.value = value
class DummyImage:
def __init__(self, arr):
self.arr = arr
def compress(self, *a, **k):
return self
class DummyDepthImage:
def __init__(self, arr, colormap=None):
self.arr = arr
self.colormap = colormap
def dummy_log(key, obj=None, **kwargs):
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
if obj is None and "entity" in kwargs:
obj = kwargs.pop("entity")
calls.append((key, obj, kwargs))
def dummy_send_blueprint(blueprint, *a, **k):
blueprints.append(blueprint)
# Mock the `rerun.blueprint` submodule used to build the layout.
dummy_rrb = SimpleNamespace(
Spatial2DView=lambda origin=None, name=None: SimpleNamespace(
kind="Spatial2DView", origin=origin, name=name
),
TimeSeriesView=lambda name=None, contents=None: SimpleNamespace(
kind="TimeSeriesView", name=name, contents=contents
),
Grid=lambda *views: SimpleNamespace(kind="Grid", views=list(views)),
Blueprint=lambda root: SimpleNamespace(kind="Blueprint", root=root),
)
dummy_rr = SimpleNamespace(
__name__="rerun",
__package__="rerun",
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
Scalars=DummyScalar,
Image=DummyImage,
DepthImage=DummyDepthImage,
components=SimpleNamespace(Colormap=SimpleNamespace(Viridis="viridis")),
log=dummy_log,
send_blueprint=dummy_send_blueprint,
init=lambda *a, **k: None,
spawn=lambda *a, **k: None,
blueprint=dummy_rrb,
)
# Inject fake modules into sys.modules (both `rerun` and `rerun.blueprint`).
monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
monkeypatch.setitem(sys.modules, "rerun.blueprint", dummy_rrb)
# Now import and reload the module under test, to bind to our rerun mock
import lerobot.utils.visualization_utils as vu
importlib.reload(vu)
# Expose the reloaded module, the call recorder and the captured blueprints
yield vu, calls, blueprints
@pytest.mark.parametrize("func", ["init_visualization", "log_visualization_data", "shutdown_visualization"])
def test_dispatch_rejects_unknown_mode(func):
with pytest.raises(ValueError, match="Unknown display_mode"):
getattr(vu, func)("bogus")
def _keys(calls):
"""Helper to extract just the keys logged to rr.log"""
return [k for (k, _obj, _kw) in calls]
def _obj_for(calls, key):
"""Find the first object logged under a given key."""
for k, obj, _kw in calls:
if k == key:
return obj
raise KeyError(f"Key {key} not found in calls: {calls}")
def _kwargs_for(calls, key):
for k, _obj, kw in calls:
if k == key:
return kw
raise KeyError(f"Key {key} not found in calls: {calls}")
def _views_by_kind(blueprint, kind):
"""Return the views of a given kind from the (single) blueprint's grid."""
return [v for v in blueprint.root.views if v.kind == kind]
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
vu, calls, blueprints = mock_rerun
# Build EnvTransition dict
obs = {
f"{OBS_STATE}.temperature": np.float32(25.0),
# CHW image should be converted to HWC for rr.Image
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
}
act = {
"action.throttle": 0.7,
# 1D array should be logged as a single Scalars batch under one entity path
"action.vector": np.array([1.0, 2.0], dtype=np.float32),
}
transition = {
TransitionKey.OBSERVATION: obs,
TransitionKey.ACTION: act,
}
# Extract observation and action data from transition like in the real call sites
obs_data = transition.get(TransitionKey.OBSERVATION, {})
action_data = transition.get(TransitionKey.ACTION, {})
vu.log_rerun_data(observation=obs_data, action=action_data)
# We expect:
# - observation.state.temperature -> Scalars
# - observation.camera -> Image (HWC) with static=True
# - action.throttle -> Scalars
# - action.vector -> single Scalars batch (no per-element suffix)
expected_keys = {
f"{OBS_STATE}.temperature",
"observation.camera",
"action.throttle",
"action.vector",
}
assert set(_keys(calls)) == expected_keys
# Check scalar types and values
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
assert type(temp_obj).__name__ == "DummyScalar"
assert float(temp_obj.value) == pytest.approx(25.0)
throttle_obj = _obj_for(calls, "action.throttle")
assert type(throttle_obj).__name__ == "DummyScalar"
assert float(throttle_obj.value) == pytest.approx(0.7)
# 1D vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vector")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [1.0, 2.0])
# Check image handling: CHW -> HWC
img_obj = _obj_for(calls, "observation.camera")
assert type(img_obj).__name__ == "DummyImage"
assert img_obj.arr.shape == (10, 20, 3) # transposed
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
# A blueprint should have been built and sent exactly once, and cached on the function.
assert len(blueprints) == 1
assert vu.log_rerun_data.blueprint is blueprints[0]
bp = blueprints[0]
# One spatial view per image path
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.camera"}
# One time-series view each for observation and action scalars
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert set(ts_views) == {"observation", "action"}
assert ts_views["observation"].contents == [f"{OBS_STATE}.temperature"]
assert ts_views["action"].contents == ["action.throttle", "action.vector"]
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
vu, calls, blueprints = mock_rerun
# First dict without prefixes treated as observation
# Second dict without prefixes treated as action
obs_plain = {
"temp": 1.5,
# Already HWC image => should stay as-is
"img": np.zeros((5, 6, 3), dtype=np.uint8),
"none": None, # should be skipped
}
act_plain = {
"throttle": 0.3,
"vec": np.array([9, 8, 7], dtype=np.float32),
}
# Extract observation and action data from list like the old function logic did
# First dict was treated as observation, second as action
vu.log_rerun_data(observation=obs_plain, action=act_plain)
# Expected keys with auto-prefixes. The 1D vector is a single batched Scalars.
expected = {
"observation.temp",
"observation.img",
"action.throttle",
"action.vec",
}
logged = set(_keys(calls))
assert logged == expected
# Scalars
t = _obj_for(calls, "observation.temp")
assert type(t).__name__ == "DummyScalar"
assert float(t.value) == pytest.approx(1.5)
throttle = _obj_for(calls, "action.throttle")
assert type(throttle).__name__ == "DummyScalar"
assert float(throttle.value) == pytest.approx(0.3)
# Image stays HWC
img = _obj_for(calls, "observation.img")
assert type(img).__name__ == "DummyImage"
assert img.arr.shape == (5, 6, 3)
assert _kwargs_for(calls, "observation.img").get("static", False) is True
# Vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vec")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [9, 8, 7])
# Blueprint sent once with the expected view layout
assert len(blueprints) == 1
bp = blueprints[0]
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.img"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.throttle", "action.vec"]
def test_log_rerun_data_kwargs_only(mock_rerun):
vu, calls, blueprints = mock_rerun
vu.log_rerun_data(
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
action={"action.a": 1.0},
)
keys = set(_keys(calls))
assert "observation.temp" in keys
assert "observation.gray" in keys
assert "action.a" in keys
temp = _obj_for(calls, "observation.temp")
assert type(temp).__name__ == "DummyScalar"
assert float(temp.value) == pytest.approx(10.0)
img = _obj_for(calls, "observation.gray")
assert type(img).__name__ == "DummyDepthImage" # single-channel -> DepthImage
assert img.arr.shape == (8, 8, 1) # remains HWC
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
a = _obj_for(calls, "action.a")
assert type(a).__name__ == "DummyScalar"
assert float(a.value) == pytest.approx(1.0)
# Blueprint sent once, with a spatial view for the image and time-series views for scalars
assert len(blueprints) == 1
bp = blueprints[0]
assert {v.origin for v in _views_by_kind(bp, "Spatial2DView")} == {"observation.gray"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.a"]
def test_log_rerun_data_blueprint_sent_only_once(mock_rerun):
"""The blueprint is built from the first call and not resent on subsequent calls."""
vu, calls, blueprints = mock_rerun
vu.log_rerun_data(observation={"temp": 1.0}, action={"a": 2.0})
assert len(blueprints) == 1
first_blueprint = vu.log_rerun_data.blueprint
vu.log_rerun_data(observation={"temp": 3.0}, action={"a": 4.0})
# Still only one blueprint, and the cached one is unchanged.
assert len(blueprints) == 1
assert vu.log_rerun_data.blueprint is first_blueprint
Generated
+1 -26
View File
@@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.12"
resolution-markers = [
"(python_full_version >= '3.15' and platform_machine == 'AMD64' and sys_platform == 'linux') or (python_full_version >= '3.15' and platform_machine == 'x86_64' and sys_platform == 'linux')",
@@ -1550,26 +1550,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/47/c99d5268f354002ce80f8d029cd9d7d872969da1de8b93d32de4dc56d6f4/fonttools-4.63.0-py3-none-any.whl", hash = "sha256:445af2eab030a16b9171ea8bdda7ebf7d96bda2df88ee182a464252f6e05e20d", size = 1164562, upload-time = "2026-05-14T12:04:29.092Z" },
]
[[package]]
name = "foxglove-sdk"
version = "0.25.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c1/a7/86a252782ea0d9baf1357369ad1bbf1ed644768702b0266a3fa3a05361d0/foxglove_sdk-0.25.1.tar.gz", hash = "sha256:8230f3c32ea3ab715818687377491594ec9c7e58e6b0ed8ed91aadf937ce706b", size = 547778, upload-time = "2026-06-02T03:13:18.942Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/58/15/59f02e8201b8da09ce05d8774820c29efc9149862b70ee6b3a27968e791a/foxglove_sdk-0.25.1-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:5af9f9a691eefbe6e0a47875ff2f7d0fc36607f0920e8690bbdc2dfd4fb22451", size = 17911538, upload-time = "2026-06-02T03:13:12.493Z" },
{ url = "https://files.pythonhosted.org/packages/27/ed/16d809fab24cbfdf97c15c9cdd80eabfeb447ca545ede426950d62bac848/foxglove_sdk-0.25.1-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:3e908bd87d1926a05c785779d8252db6b87eef685f284ec1cf46ee501645d08e", size = 16452309, upload-time = "2026-06-02T03:13:10.607Z" },
{ url = "https://files.pythonhosted.org/packages/d6/c3/f95874935a3436841487df1f0202de4d20eabc0adb6b79c94c531bbe7eb3/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:968e32c8668d172f6b546c8e7af658ed35a21ec165adc3bacf53a04dda159f12", size = 2355680, upload-time = "2026-06-02T02:34:01.668Z" },
{ url = "https://files.pythonhosted.org/packages/38/da/ad22d8d6e3fedde9fc0c49aa8b20394e5e0bc44ab3fba564c77a64ddc7e2/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3f75374fedafe259c40b19bc645589d9453708eab679a5b07c603035f936d29a", size = 2274075, upload-time = "2026-06-02T02:34:07.212Z" },
{ url = "https://files.pythonhosted.org/packages/a3/fa/1254adb5e72eff507695473e9c82d0e90395b61463e5353762250db30d3d/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b6d3af517a00342bf7b08a4a65b043f3eafaa197138752b6fbd704fb91043fa", size = 2282160, upload-time = "2026-06-02T02:34:08.812Z" },
{ url = "https://files.pythonhosted.org/packages/7c/e4/2b22ef06ba4058494c7aa35974d138f8f1ae4cf5273f77d69c9dc3a99b45/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:aed27c0f03a45fd6abdd566498bfee2672391602bcff32c827b8e3a6d8f67ab1", size = 22685338, upload-time = "2026-06-02T02:34:04.688Z" },
{ url = "https://files.pythonhosted.org/packages/35/7c/58324c99b80eef0b674c8d4f5c2e07c66fd1480a27a8f0d4d79371805111/foxglove_sdk-0.25.1-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:419dd8308e3f91e2ae487b727f1bf1804642990876163b2a353db4a1b1de1425", size = 19326096, upload-time = "2026-06-02T02:34:10.939Z" },
{ url = "https://files.pythonhosted.org/packages/fe/9c/3452d92959e05fc6b1c1e5f032605d55623aeb6704357d20408f8781bc84/foxglove_sdk-0.25.1-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0fcb36e628ab3d9043e193f12ad4dbbb955fe18616aac7ef5bca82c52910f108", size = 2539020, upload-time = "2026-06-02T03:13:14.365Z" },
{ url = "https://files.pythonhosted.org/packages/b5/af/57fa58525d3acb5c5480a6f0ef86450b1a0ccae2b21248edb1376073ce55/foxglove_sdk-0.25.1-cp310-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:7909fd9f94935935dd8813702d84ffdbfebeb3866673c618ce35e8cfedd03029", size = 2550999, upload-time = "2026-06-02T03:13:15.715Z" },
{ url = "https://files.pythonhosted.org/packages/90/78/f74bb167186c965d475ff360fa6eb7441d5ac6c6239d60f542f63984f849/foxglove_sdk-0.25.1-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:69d5966213b5212b8841b4004fe582db924a74f1610d8452ad890f6931702926", size = 2560166, upload-time = "2026-06-02T03:13:17.254Z" },
{ url = "https://files.pythonhosted.org/packages/81/83/1c4c6d04fbd4784fe44fb2da021db1adf1f03a371f1e5679a383c1173235/foxglove_sdk-0.25.1-cp310-abi3-win32.whl", hash = "sha256:2a1121a5c74590ff6e61628c4a46dc57d392d290b4beeb29d6852933da56224a", size = 1618124, upload-time = "2026-06-02T03:13:20.158Z" },
{ url = "https://files.pythonhosted.org/packages/5f/4d/bdb9e252a41a951eb53908ac9cb965b7480c3ba649174f5398d4fcf0ca1d/foxglove_sdk-0.25.1-cp310-abi3-win_amd64.whl", hash = "sha256:6ed3ad0d3e72cd7875e7e293709c5ff90494fe14f1b48a336baffc313a7272cc", size = 16588452, upload-time = "2026-06-02T03:13:21.636Z" },
]
[[package]]
name = "fqdn"
version = "1.5.1"
@@ -2831,7 +2811,6 @@ all = [
{ name = "faker" },
{ name = "fastapi" },
{ name = "feetech-servo-sdk" },
{ name = "foxglove-sdk" },
{ name = "grpcio" },
{ name = "grpcio-tools" },
{ name = "gym-aloha" },
@@ -2916,7 +2895,6 @@ core-scripts = [
{ name = "av" },
{ name = "datasets" },
{ name = "deepdiff" },
{ name = "foxglove-sdk" },
{ name = "jsonlines" },
{ name = "pandas" },
{ name = "pyarrow" },
@@ -2939,7 +2917,6 @@ dataset = [
dataset-viz = [
{ name = "av" },
{ name = "datasets" },
{ name = "foxglove-sdk" },
{ name = "jsonlines" },
{ name = "pandas" },
{ name = "pyarrow" },
@@ -3206,7 +3183,6 @@ video-benchmark = [
{ name = "scikit-image" },
]
viz = [
{ name = "foxglove-sdk" },
{ name = "rerun-sdk" },
]
vla-jepa = [
@@ -3246,7 +3222,6 @@ requires-dist = [
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
{ name = "foxglove-sdk", marker = "extra == 'viz'", specifier = ">=0.25.1,<0.26.0" },
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = ">=1.73.1,<2.0.0" },
{ name = "grpcio", marker = "extra == 'reachy2'", specifier = "<=1.73.1" },
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.73.1,<2.0.0" },