diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 89006fd57..2c94736ee 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -157,14 +157,13 @@ def visualize_dataset( batch_size: int = 32, num_workers: int = 0, mode: str = "local", - web_port: int = 9090, + web_port: int | None = None, grpc_port: int = 9876, save: bool = False, output_dir: Path | None = None, display_compressed_images: bool = False, display_mode: str = "rerun", host: str = "127.0.0.1", - port: int = 8765, **kwargs, ) -> Path | None: if display_mode == "foxglove": @@ -177,7 +176,7 @@ def visualize_dataset( dataset, episode_index, host=host, - port=port, + port=web_port if web_port is not None else 8765, compress_images=display_compressed_images, ) return None @@ -218,7 +217,9 @@ def visualize_dataset( if mode == "distant": server_uri = rr.serve_grpc(grpc_port=grpc_port) logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy") - rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri) + rr.serve_web_viewer( + open_browser=False, web_port=web_port if web_port is not None else 9090, connect_to=server_uri + ) logging.info("Logging to Rerun") @@ -342,8 +343,11 @@ def main(): parser.add_argument( "--web-port", type=int, - default=9090, - help="Web port for rerun.io when `--mode distant` is set.", + default=None, + help=( + "Web/WebSocket port. For rerun `--mode distant` it is the web viewer port (default 9090); " + "for `--display-mode foxglove` it is the server bind port (default 8765)." + ), ) parser.add_argument( "--ws-port", @@ -392,20 +396,17 @@ def main(): help=( "Visualization backend. 'rerun' uses the Rerun viewer (--mode/--save/--*-port apply). " "'foxglove' starts a Foxglove WebSocket server that serves the episode as a seekable, " - "scrubbable timeline; connect the Foxglove app to ws://HOST:PORT (--host/--port)." 
+ "scrubbable timeline; connect the Foxglove app to ws://HOST:PORT (--host/--web-port)." ), ) parser.add_argument( "--host", type=str, default="127.0.0.1", - help="Host to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.", - ) - parser.add_argument( - "--port", - type=int, - default=8765, - help="Port to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.", + help=( + "Host to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set " + "(127.0.0.1 for local only, 0.0.0.0 for all interfaces)." + ), ) args = parser.parse_args() diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index b9012811a..4233f0bc2 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -176,11 +176,11 @@ class RecordConfig: # Display all cameras on screen display_data: bool = False # Visualization backend used when display_data is True: "rerun" or "foxglove". - # "foxglove" starts a WebSocket server (default ws://127.0.0.1:8765) to stream data to the Foxglove app. display_mode: str = "rerun" - # For "rerun": IP of a remote Rerun server to connect to. Unused by "foxglove". + # For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket + # server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces). display_ip: str | None = None - # For "rerun": port of the remote Rerun server. For "foxglove": port to bind the WebSocket server to. + # For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to. 
display_port: int | None = None # Whether to display compressed (JPEG) images instead of raw frames display_compressed_images: bool = False diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 924e92c8a..0c7f49010 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -142,11 +142,11 @@ class TeleoperateConfig: # Display all cameras on screen display_data: bool = False # Visualization backend used when display_data is True: "rerun" or "foxglove". - # "foxglove" starts a WebSocket server (default ws://127.0.0.1:8765) to stream data to the Foxglove app. display_mode: str = "rerun" - # For "rerun": IP of a remote Rerun server to connect to. Unused by "foxglove". + # For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket + # server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces). display_ip: str | None = None - # For "rerun": port of the remote Rerun server. For "foxglove": port to bind the WebSocket server to. + # For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to. 
display_port: int | None = None # Whether to display compressed (JPEG) images instead of raw frames display_compressed_images: bool = False diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 482394ff6..8f735fe6d 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -37,6 +37,7 @@ ACTION_TOKEN_MASK = ACTION + ".token_mask" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done" +SUCCESS = "next.success" INFO = "info" ROBOTS = "robots" diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 335623d64..bc988ecf4 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -20,15 +20,40 @@ import numpy as np from lerobot.types import RobotAction, RobotObservation -from .constants import ACTION, ACTION_PREFIX, DONE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD +from .constants import ACTION, ACTION_PREFIX, DONE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD, SUCCESS from .import_utils import require_package +# Visualization backends selectable at runtime via a display-mode string (e.g. a --display_mode flag). +VISUALIZATION_MODES = ("rerun", "foxglove") + # Module-level Foxglove state. A single WebSocket server is shared for the # process lifetime, and image channels are cached by topic (the Foxglove SDK # requires reusing one channel per topic). _foxglove_server = None _foxglove_channels: dict = {} +# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}`` +# pairs rather than one field per feature, so the same schema fits any robot regardless of which +# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name +# each series automatically, so a single filtered path plots every feature, e.g. +# ``/observation/state.scalars[:]``. 
+_SCALARS_SCHEMA = { + "type": "object", + "title": "lerobot.Scalars", + "properties": { + "scalars": { + "type": "array", + "items": { + "type": "object", + "properties": { + "label": {"type": "string"}, + "value": {"type": "number"}, + }, + }, + } + }, +} + def init_rerun( session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None @@ -148,38 +173,32 @@ def _foxglove_safe_name(name: str) -> str: return name.replace(".", "_") -# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}`` -# pairs rather than one field per feature, so the same schema fits any robot regardless of which -# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name -# each series automatically, so a single filtered path plots every feature, e.g. -# ``/observation/state.scalars[:].value``. -_SCALARS_SCHEMA = { - "type": "object", - "title": "lerobot.Scalars", - "properties": { - "scalars": { - "type": "array", - "items": { - "type": "object", - "properties": { - "label": {"type": "string"}, - "value": {"type": "number"}, - }, - }, - } - }, -} +def _foxglove_topic(key: str, *, is_image: bool = False) -> str: + """Build the Foxglove topic for a feature ``key``. + + Camera features map to a per-source image topic (``/observation/images/<name>``); scalar features + share one aggregate topic per source: ``/observation/state`` for observations, ``/action/state`` + for actions. 
+ """ + + if is_image: + name = key[len(OBS_PREFIX) :] if str(key).startswith(OBS_PREFIX) else str(key) + return f"/{OBS_STR}/images/{_foxglove_safe_name(name)}" + source = ACTION if (str(key).startswith(ACTION_PREFIX) or str(key) == ACTION) else OBS_STR + return f"/{source}/state" -def _log_foxglove_scalars(topic: str, values: dict[str, float], *, log_time: int | None = None) -> None: +def _log_foxglove_scalars( + topic: str, values: dict[str, float], *, channels: dict | None = None, log_time: int | None = None +) -> None: """Log scalars on a typed JSON channel using the static :data:`_SCALARS_SCHEMA`. ``values`` is an ordered mapping of feature name to value; it is emitted as a ``scalars`` array of ``{label, value}`` objects. Insertion order is preserved so series stay stable across messages. - ``log_time`` is the message time in nanoseconds. When ``None`` the server's receive time is used - (correct for live streaming); dataset playback passes the frame's dataset timestamp so the - Foxglove timeline reflects the recorded episode. + ``channels`` is the per-topic channel cache to reuse (defaults to the module-global cache used by + live streaming; dataset playback passes its own local cache to stay self-contained). ``log_time`` + is the message time in nanoseconds; when ``None`` the server's receive time is used. 
""" if not values: @@ -187,11 +206,11 @@ def _log_foxglove_scalars(topic: str, values: dict[str, float], *, log_time: int import foxglove - channel = _foxglove_channels.get(topic) + if channels is None: + channels = _foxglove_channels + channel = channels.get(topic) if channel is None: - channel = _foxglove_channels[topic] = foxglove.Channel( - topic, schema=_SCALARS_SCHEMA, message_encoding="json" - ) + channel = channels[topic] = foxglove.Channel(topic, schema=_SCALARS_SCHEMA, message_encoding="json") msg = {"scalars": [{"label": label, "value": value} for label, value in values.items()]} if log_time is None: channel.log(msg) @@ -200,47 +219,57 @@ def _log_foxglove_scalars(topic: str, values: dict[str, float], *, log_time: int def _log_foxglove_image( - topic: str, frame_id: str, arr: np.ndarray, *, compress_images: bool, time_ns: int + topic: str, + frame_id: str, + arr: np.ndarray, + *, + compress_images: bool, + channels: dict | None = None, + log_time: int | None = None, ) -> None: - """Log an image on a cached per-topic channel, stamped at ``time_ns`` (nanoseconds). + """Log an image on a cached per-topic channel. - ``arr`` may be HWC or CHW; CHW is transposed to HWC. ``time_ns`` sets both the message header - timestamp and the channel ``log_time`` so the message lands at the right point on the Foxglove - timeline (matching wall-clock for live streaming, or the dataset timestamp during playback). + ``arr`` may be HWC or CHW; CHW is transposed to HWC. ``channels`` is the per-topic channel cache + to reuse (see :func:`_log_foxglove_scalars`). ``log_time`` is the message time in nanoseconds; when + ``None`` the server's receive time is used. It is also written to the message header timestamp. 
""" from foxglove.channels import CompressedImageChannel, RawImageChannel from foxglove.messages import CompressedImage, RawImage, Timestamp + if channels is None: + channels = _foxglove_channels + time_ns = time.time_ns() if log_time is None else log_time timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000) + log_kwargs = {} if log_time is None else {"log_time": log_time} # Convert CHW -> HWC when needed (mirrors log_rerun_data). if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): arr = np.transpose(arr, (1, 2, 0)) height, width = arr.shape[0], arr.shape[1] - channels = 1 if arr.ndim == 2 else arr.shape[2] + n_channels = 1 if arr.ndim == 2 else arr.shape[2] - if compress_images and channels == 3: + if compress_images and n_channels == 3: import cv2 # Camera frames are RGB; cv2.imencode assumes BGR, so swap to keep colors correct. _, buf = cv2.imencode(".jpg", cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)) - channel = _foxglove_channels.get(topic) + channel = channels.get(topic) if channel is None: - channel = _foxglove_channels[topic] = CompressedImageChannel(topic=topic) + channel = channels[topic] = CompressedImageChannel(topic=topic) channel.log( CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg"), - log_time=time_ns, + **log_kwargs, ) return - encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(channels) + encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(n_channels) if encoding is None: return arr = np.ascontiguousarray(arr, dtype=np.uint8) - channel = _foxglove_channels.get(topic) + channel = channels.get(topic) if channel is None: - channel = _foxglove_channels[topic] = RawImageChannel(topic=topic) + channel = channels[topic] = RawImageChannel(topic=topic) channel.log( RawImage( timestamp=timestamp, @@ -248,10 +277,10 @@ def _log_foxglove_image( width=width, height=height, encoding=encoding, - step=width * channels, + step=width * n_channels, data=arr.tobytes(), ), - 
log_time=time_ns, + **log_kwargs, ) @@ -365,9 +394,6 @@ def log_foxglove_data( now = time.time_ns() - def log_image(topic: str, frame_id: str, arr: np.ndarray) -> None: - _log_foxglove_image(topic, frame_id, arr, compress_images=compress_images, time_ns=now) - if observation: obs_scalars: dict[str, float] = {} for k, v in observation.items(): @@ -381,9 +407,14 @@ def log_foxglove_data( for i, vi in enumerate(v): obs_scalars[f"{key}_{i}"] = float(vi) else: - # Image topics still sanitize the name since it's used as a topic-path segment. - log_image(f"/{OBS_STR}/images/{_foxglove_safe_name(key)}", key, v) - _log_foxglove_scalars(f"/{OBS_STR}/state", obs_scalars) + _log_foxglove_image( + _foxglove_topic(k, is_image=True), + key, + v, + compress_images=compress_images, + log_time=now, + ) + _log_foxglove_scalars(_foxglove_topic(OBS_STATE), obs_scalars, log_time=now) if action: action_scalars: dict[str, float] = {} @@ -396,7 +427,7 @@ def log_foxglove_data( elif isinstance(v, np.ndarray): for i, vi in enumerate(v.flatten()): action_scalars[f"{key}_{i}"] = float(vi) - _log_foxglove_scalars(f"/{ACTION}/state", action_scalars) + _log_foxglove_scalars(_foxglove_topic(ACTION), action_scalars, log_time=now) # ── Dataset playback over a Foxglove WebSocket server ───────────────────── @@ -404,8 +435,6 @@ def log_foxglove_data( # advertise a seekable timeline and serve frames on demand for whatever time the user scrubs/plays # to in the Foxglove app. This relies on the SDK's PlaybackControl capability. -_SUCCESS = "next.success" - def _feature_dim_names(feature: dict | None) -> list[str] | None: """Best-effort per-dimension series labels for a 1D feature, or ``None`` to fall back to indices. 
@@ -502,10 +531,8 @@ def serve_foxglove_dataset_playback( OBS_STATE: _feature_dim_names(dataset.meta.features.get(OBS_STATE)), ACTION: _feature_dim_names(dataset.meta.features.get(ACTION)), } - - def topic_for(key: str) -> str: - name = key[len(OBS_PREFIX) :] if str(key).startswith(OBS_PREFIX) else str(key) - return f"/{OBS_STR}/images/{_foxglove_safe_name(name)}" + # Local channel cache so the playback server is self-contained and doesn't touch the module global. + channels: dict = {} def emit_frame(i: int) -> None: """Log every channel for frame ``i`` stamped at its dataset timestamp.""" @@ -518,21 +545,32 @@ def serve_foxglove_dataset_playback( arr = arr.numpy() if hasattr(arr, "numpy") else np.asarray(arr) if np.issubdtype(arr.dtype, np.floating): arr = (arr * 255.0).clip(0, 255).astype(np.uint8) - _log_foxglove_image(topic_for(key), key, arr, compress_images=compress_images, time_ns=log_time) + _log_foxglove_image( + _foxglove_topic(key, is_image=True), + key, + arr, + compress_images=compress_images, + channels=channels, + log_time=log_time, + ) _log_foxglove_scalars( - f"/{OBS_STR}/state", + _foxglove_topic(OBS_STATE), _frame_to_scalars(sample, OBS_STATE, scalar_labels[OBS_STATE]), + channels=channels, log_time=log_time, ) _log_foxglove_scalars( - f"/{ACTION}/state", _frame_to_scalars(sample, ACTION, scalar_labels[ACTION]), log_time=log_time + _foxglove_topic(ACTION), + _frame_to_scalars(sample, ACTION, scalar_labels[ACTION]), + channels=channels, + log_time=log_time, ) episode_scalars = {} - for feat, label in ((DONE, "done"), (REWARD, "reward"), (_SUCCESS, "success")): + for feat, label in ((DONE, "done"), (REWARD, "reward"), (SUCCESS, "success")): v = sample.get(feat) if v is not None: episode_scalars[label] = float(v) - _log_foxglove_scalars("/episode/state", episode_scalars, log_time=log_time) + _log_foxglove_scalars("/episode/state", episode_scalars, channels=channels, log_time=log_time) lock = threading.Lock() stop_event = threading.Event() @@ -644,15 
+682,13 @@ def serve_foxglove_dataset_playback( stop_event.set() thread.join(timeout=2.0) server.stop() - _foxglove_channels.clear() + channels.clear() # ── Backend-agnostic dispatch ───────────────────────────────────────────── # These let callers select a visualization backend at runtime via a string # (e.g. a `--display_mode` CLI flag) without branching on the backend everywhere. -VISUALIZATION_MODES = ("rerun", "foxglove") - def init_visualization( display_mode: str, @@ -664,13 +700,14 @@ def init_visualization( """Initializes the visualization backend selected by ``display_mode``. For ``"rerun"``, ``ip``/``port`` point at an optional remote Rerun server. For ``"foxglove"``, - ``port`` is the local WebSocket server port (``ip`` is ignored; the server binds locally). + ``ip`` is the interface to bind the WebSocket server to (``127.0.0.1`` for local only, ``0.0.0.0`` + for all interfaces) and ``port`` is its port. """ if display_mode == "rerun": init_rerun(session_name=session_name, ip=ip, port=port) elif display_mode == "foxglove": - init_foxglove(port=port) + init_foxglove(host=ip or "127.0.0.1", port=port) else: raise ValueError(f"Unknown display_mode '{display_mode}'. Expected one of {VISUALIZATION_MODES}.")