mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 17:17:01 +00:00
Make Foxglove server host bindable and refactor topic/channel handling
Pass display_ip through as the Foxglove WebSocket bind host (127.0.0.1 for local only, 0.0.0.0 for all interfaces) instead of always binding locally. In lerobot-dataset-viz, fold the separate --port into --web-port so one flag covers both the Rerun web viewer and the Foxglove server port. Add a _foxglove_topic() helper and thread a per-topic channel cache through the log helpers so dataset playback stays self-contained instead of mutating the module-global cache. Promote SUCCESS to constants.py.
This commit is contained in:
committed by
CarolinePascal
parent
802f49438c
commit
17c83a7330
@@ -157,14 +157,13 @@ def visualize_dataset(
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 0,
|
||||
mode: str = "local",
|
||||
web_port: int = 9090,
|
||||
web_port: int | None = None,
|
||||
grpc_port: int = 9876,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
display_mode: str = "rerun",
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8765,
|
||||
**kwargs,
|
||||
) -> Path | None:
|
||||
if display_mode == "foxglove":
|
||||
@@ -177,7 +176,7 @@ def visualize_dataset(
|
||||
dataset,
|
||||
episode_index,
|
||||
host=host,
|
||||
port=port,
|
||||
port=web_port if web_port is not None else 8765,
|
||||
compress_images=display_compressed_images,
|
||||
)
|
||||
return None
|
||||
@@ -218,7 +217,9 @@ def visualize_dataset(
|
||||
if mode == "distant":
|
||||
server_uri = rr.serve_grpc(grpc_port=grpc_port)
|
||||
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
|
||||
rr.serve_web_viewer(
|
||||
open_browser=False, web_port=web_port if web_port is not None else 9090, connect_to=server_uri
|
||||
)
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
@@ -342,8 +343,11 @@ def main():
|
||||
parser.add_argument(
|
||||
"--web-port",
|
||||
type=int,
|
||||
default=9090,
|
||||
help="Web port for rerun.io when `--mode distant` is set.",
|
||||
default=None,
|
||||
help=(
|
||||
"Web/WebSocket port. For rerun `--mode distant` it is the web viewer port (default 9090); "
|
||||
"for `--display-mode foxglove` it is the server bind port (default 8765)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ws-port",
|
||||
@@ -392,20 +396,17 @@ def main():
|
||||
help=(
|
||||
"Visualization backend. 'rerun' uses the Rerun viewer (--mode/--save/--*-port apply). "
|
||||
"'foxglove' starts a Foxglove WebSocket server that serves the episode as a seekable, "
|
||||
"scrubbable timeline; connect the Foxglove app to ws://HOST:PORT (--host/--port)."
|
||||
"scrubbable timeline; connect the Foxglove app to ws://HOST:PORT (--host/--web-port)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8765,
|
||||
help="Port to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set.",
|
||||
help=(
|
||||
"Host to bind the Foxglove WebSocket server to when `--display-mode foxglove` is set "
|
||||
"(127.0.0.1 for local only, 0.0.0.0 for all interfaces)."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -176,11 +176,11 @@ class RecordConfig:
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Visualization backend used when display_data is True: "rerun" or "foxglove".
|
||||
# "foxglove" starts a WebSocket server (default ws://127.0.0.1:8765) to stream data to the Foxglove app.
|
||||
display_mode: str = "rerun"
|
||||
# For "rerun": IP of a remote Rerun server to connect to. Unused by "foxglove".
|
||||
# For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket
|
||||
# server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces).
|
||||
display_ip: str | None = None
|
||||
# For "rerun": port of the remote Rerun server. For "foxglove": port to bind the WebSocket server to.
|
||||
# For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to.
|
||||
display_port: int | None = None
|
||||
# Whether to display compressed (JPEG) images instead of raw frames
|
||||
display_compressed_images: bool = False
|
||||
|
||||
@@ -142,11 +142,11 @@ class TeleoperateConfig:
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Visualization backend used when display_data is True: "rerun" or "foxglove".
|
||||
# "foxglove" starts a WebSocket server (default ws://127.0.0.1:8765) to stream data to the Foxglove app.
|
||||
display_mode: str = "rerun"
|
||||
# For "rerun": IP of a remote Rerun server to connect to. Unused by "foxglove".
|
||||
# For "rerun": IP of a remote server to send to. For "foxglove": interface to bind the WebSocket
|
||||
# server to (127.0.0.1 for local only, 0.0.0.0 for all interfaces).
|
||||
display_ip: str | None = None
|
||||
# For "rerun": port of the remote Rerun server. For "foxglove": port to bind the WebSocket server to.
|
||||
# For "rerun": port of the remote server. For "foxglove": port to bind the WebSocket server to.
|
||||
display_port: int | None = None
|
||||
# Whether to display compressed (JPEG) images instead of raw frames
|
||||
display_compressed_images: bool = False
|
||||
|
||||
@@ -37,6 +37,7 @@ ACTION_TOKEN_MASK = ACTION + ".token_mask"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
SUCCESS = "next.success"
|
||||
INFO = "info"
|
||||
|
||||
ROBOTS = "robots"
|
||||
|
||||
@@ -20,15 +20,40 @@ import numpy as np
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
|
||||
from .constants import ACTION, ACTION_PREFIX, DONE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD
|
||||
from .constants import ACTION, ACTION_PREFIX, DONE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD, SUCCESS
|
||||
from .import_utils import require_package
|
||||
|
||||
# Visualization backends selectable at runtime via a display-mode string (e.g. a --display_mode flag).
|
||||
VISUALIZATION_MODES = ("rerun", "foxglove")
|
||||
|
||||
# Module-level Foxglove state. A single WebSocket server is shared for the
|
||||
# process lifetime, and image channels are cached by topic (the Foxglove SDK
|
||||
# requires reusing one channel per topic).
|
||||
_foxglove_server = None
|
||||
_foxglove_channels: dict = {}
|
||||
|
||||
# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}``
|
||||
# pairs rather than one field per feature, so the same schema fits any robot regardless of which
|
||||
# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name
|
||||
# each series automatically, so a single filtered path plots every feature, e.g.
|
||||
# ``/observation/state.scalars[:]``.
|
||||
_SCALARS_SCHEMA = {
|
||||
"type": "object",
|
||||
"title": "lerobot.Scalars",
|
||||
"properties": {
|
||||
"scalars": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {"type": "string"},
|
||||
"value": {"type": "number"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def init_rerun(
|
||||
session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
|
||||
@@ -148,38 +173,32 @@ def _foxglove_safe_name(name: str) -> str:
|
||||
return name.replace(".", "_")
|
||||
|
||||
|
||||
# Static schema shared by all scalar topics. Each message carries a flat list of ``{label, value}``
|
||||
# pairs rather than one field per feature, so the same schema fits any robot regardless of which
|
||||
# observation/action features it reports. The ``label`` field name is what Foxglove looks for to name
|
||||
# each series automatically, so a single filtered path plots every feature, e.g.
|
||||
# ``/observation/state.scalars[:].value``.
|
||||
_SCALARS_SCHEMA = {
|
||||
"type": "object",
|
||||
"title": "lerobot.Scalars",
|
||||
"properties": {
|
||||
"scalars": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {"type": "string"},
|
||||
"value": {"type": "number"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
def _foxglove_topic(key: str, *, is_image: bool = False) -> str:
|
||||
"""Build the Foxglove topic for a feature ``key``.
|
||||
|
||||
Camera features map to a per-source image topic (``/observation/images/<name>``); scalar features
|
||||
share one aggregate topic per source: ``/observation/state`` for observations, ``/action/state``
|
||||
for actions.
|
||||
"""
|
||||
|
||||
if is_image:
|
||||
name = key[len(OBS_PREFIX) :] if str(key).startswith(OBS_PREFIX) else str(key)
|
||||
return f"/{OBS_STR}/images/{_foxglove_safe_name(name)}"
|
||||
source = ACTION if (str(key).startswith(ACTION_PREFIX) or str(key) == ACTION) else OBS_STR
|
||||
return f"/{source}/state"
|
||||
|
||||
|
||||
def _log_foxglove_scalars(topic: str, values: dict[str, float], *, log_time: int | None = None) -> None:
|
||||
def _log_foxglove_scalars(
|
||||
topic: str, values: dict[str, float], *, channels: dict | None = None, log_time: int | None = None
|
||||
) -> None:
|
||||
"""Log scalars on a typed JSON channel using the static :data:`_SCALARS_SCHEMA`.
|
||||
|
||||
``values`` is an ordered mapping of feature name to value; it is emitted as a ``scalars`` array of
|
||||
``{label, value}`` objects. Insertion order is preserved so series stay stable across messages.
|
||||
|
||||
``log_time`` is the message time in nanoseconds. When ``None`` the server's receive time is used
|
||||
(correct for live streaming); dataset playback passes the frame's dataset timestamp so the
|
||||
Foxglove timeline reflects the recorded episode.
|
||||
``channels`` is the per-topic channel cache to reuse (defaults to the module-global cache used by
|
||||
live streaming; dataset playback passes its own local cache to stay self-contained). ``log_time``
|
||||
is the message time in nanoseconds; when ``None`` the server's receive time is used.
|
||||
"""
|
||||
|
||||
if not values:
|
||||
@@ -187,11 +206,11 @@ def _log_foxglove_scalars(topic: str, values: dict[str, float], *, log_time: int
|
||||
|
||||
import foxglove
|
||||
|
||||
channel = _foxglove_channels.get(topic)
|
||||
if channels is None:
|
||||
channels = _foxglove_channels
|
||||
channel = channels.get(topic)
|
||||
if channel is None:
|
||||
channel = _foxglove_channels[topic] = foxglove.Channel(
|
||||
topic, schema=_SCALARS_SCHEMA, message_encoding="json"
|
||||
)
|
||||
channel = channels[topic] = foxglove.Channel(topic, schema=_SCALARS_SCHEMA, message_encoding="json")
|
||||
msg = {"scalars": [{"label": label, "value": value} for label, value in values.items()]}
|
||||
if log_time is None:
|
||||
channel.log(msg)
|
||||
@@ -200,47 +219,57 @@ def _log_foxglove_scalars(topic: str, values: dict[str, float], *, log_time: int
|
||||
|
||||
|
||||
def _log_foxglove_image(
|
||||
topic: str, frame_id: str, arr: np.ndarray, *, compress_images: bool, time_ns: int
|
||||
topic: str,
|
||||
frame_id: str,
|
||||
arr: np.ndarray,
|
||||
*,
|
||||
compress_images: bool,
|
||||
channels: dict | None = None,
|
||||
log_time: int | None = None,
|
||||
) -> None:
|
||||
"""Log an image on a cached per-topic channel, stamped at ``time_ns`` (nanoseconds).
|
||||
"""Log an image on a cached per-topic channel.
|
||||
|
||||
``arr`` may be HWC or CHW; CHW is transposed to HWC. ``time_ns`` sets both the message header
|
||||
timestamp and the channel ``log_time`` so the message lands at the right point on the Foxglove
|
||||
timeline (matching wall-clock for live streaming, or the dataset timestamp during playback).
|
||||
``arr`` may be HWC or CHW; CHW is transposed to HWC. ``channels`` is the per-topic channel cache
|
||||
to reuse (see :func:`_log_foxglove_scalars`). ``log_time`` is the message time in nanoseconds; when
|
||||
``None`` the server's receive time is used. It is also written to the message header timestamp.
|
||||
"""
|
||||
|
||||
from foxglove.channels import CompressedImageChannel, RawImageChannel
|
||||
from foxglove.messages import CompressedImage, RawImage, Timestamp
|
||||
|
||||
if channels is None:
|
||||
channels = _foxglove_channels
|
||||
time_ns = time.time_ns() if log_time is None else log_time
|
||||
timestamp = Timestamp(sec=time_ns // 1_000_000_000, nsec=time_ns % 1_000_000_000)
|
||||
log_kwargs = {} if log_time is None else {"log_time": log_time}
|
||||
|
||||
# Convert CHW -> HWC when needed (mirrors log_rerun_data).
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
height, width = arr.shape[0], arr.shape[1]
|
||||
channels = 1 if arr.ndim == 2 else arr.shape[2]
|
||||
n_channels = 1 if arr.ndim == 2 else arr.shape[2]
|
||||
|
||||
if compress_images and channels == 3:
|
||||
if compress_images and n_channels == 3:
|
||||
import cv2
|
||||
|
||||
# Camera frames are RGB; cv2.imencode assumes BGR, so swap to keep colors correct.
|
||||
_, buf = cv2.imencode(".jpg", cv2.cvtColor(arr, cv2.COLOR_RGB2BGR))
|
||||
channel = _foxglove_channels.get(topic)
|
||||
channel = channels.get(topic)
|
||||
if channel is None:
|
||||
channel = _foxglove_channels[topic] = CompressedImageChannel(topic=topic)
|
||||
channel = channels[topic] = CompressedImageChannel(topic=topic)
|
||||
channel.log(
|
||||
CompressedImage(timestamp=timestamp, frame_id=frame_id, data=buf.tobytes(), format="jpeg"),
|
||||
log_time=time_ns,
|
||||
**log_kwargs,
|
||||
)
|
||||
return
|
||||
|
||||
encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(channels)
|
||||
encoding = {1: "mono8", 3: "rgb8", 4: "rgba8"}.get(n_channels)
|
||||
if encoding is None:
|
||||
return
|
||||
arr = np.ascontiguousarray(arr, dtype=np.uint8)
|
||||
channel = _foxglove_channels.get(topic)
|
||||
channel = channels.get(topic)
|
||||
if channel is None:
|
||||
channel = _foxglove_channels[topic] = RawImageChannel(topic=topic)
|
||||
channel = channels[topic] = RawImageChannel(topic=topic)
|
||||
channel.log(
|
||||
RawImage(
|
||||
timestamp=timestamp,
|
||||
@@ -248,10 +277,10 @@ def _log_foxglove_image(
|
||||
width=width,
|
||||
height=height,
|
||||
encoding=encoding,
|
||||
step=width * channels,
|
||||
step=width * n_channels,
|
||||
data=arr.tobytes(),
|
||||
),
|
||||
log_time=time_ns,
|
||||
**log_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -365,9 +394,6 @@ def log_foxglove_data(
|
||||
|
||||
now = time.time_ns()
|
||||
|
||||
def log_image(topic: str, frame_id: str, arr: np.ndarray) -> None:
|
||||
_log_foxglove_image(topic, frame_id, arr, compress_images=compress_images, time_ns=now)
|
||||
|
||||
if observation:
|
||||
obs_scalars: dict[str, float] = {}
|
||||
for k, v in observation.items():
|
||||
@@ -381,9 +407,14 @@ def log_foxglove_data(
|
||||
for i, vi in enumerate(v):
|
||||
obs_scalars[f"{key}_{i}"] = float(vi)
|
||||
else:
|
||||
# Image topics still sanitize the name since it's used as a topic-path segment.
|
||||
log_image(f"/{OBS_STR}/images/{_foxglove_safe_name(key)}", key, v)
|
||||
_log_foxglove_scalars(f"/{OBS_STR}/state", obs_scalars)
|
||||
_log_foxglove_image(
|
||||
_foxglove_topic(k, is_image=True),
|
||||
key,
|
||||
v,
|
||||
compress_images=compress_images,
|
||||
log_time=now,
|
||||
)
|
||||
_log_foxglove_scalars(_foxglove_topic(OBS_STATE), obs_scalars, log_time=now)
|
||||
|
||||
if action:
|
||||
action_scalars: dict[str, float] = {}
|
||||
@@ -396,7 +427,7 @@ def log_foxglove_data(
|
||||
elif isinstance(v, np.ndarray):
|
||||
for i, vi in enumerate(v.flatten()):
|
||||
action_scalars[f"{key}_{i}"] = float(vi)
|
||||
_log_foxglove_scalars(f"/{ACTION}/state", action_scalars)
|
||||
_log_foxglove_scalars(_foxglove_topic(ACTION), action_scalars, log_time=now)
|
||||
|
||||
|
||||
# ── Dataset playback over a Foxglove WebSocket server ─────────────────────
|
||||
@@ -404,8 +435,6 @@ def log_foxglove_data(
|
||||
# advertise a seekable timeline and serve frames on demand for whatever time the user scrubs/plays
|
||||
# to in the Foxglove app. This relies on the SDK's PlaybackControl capability.
|
||||
|
||||
_SUCCESS = "next.success"
|
||||
|
||||
|
||||
def _feature_dim_names(feature: dict | None) -> list[str] | None:
|
||||
"""Best-effort per-dimension series labels for a 1D feature, or ``None`` to fall back to indices.
|
||||
@@ -502,10 +531,8 @@ def serve_foxglove_dataset_playback(
|
||||
OBS_STATE: _feature_dim_names(dataset.meta.features.get(OBS_STATE)),
|
||||
ACTION: _feature_dim_names(dataset.meta.features.get(ACTION)),
|
||||
}
|
||||
|
||||
def topic_for(key: str) -> str:
|
||||
name = key[len(OBS_PREFIX) :] if str(key).startswith(OBS_PREFIX) else str(key)
|
||||
return f"/{OBS_STR}/images/{_foxglove_safe_name(name)}"
|
||||
# Local channel cache so the playback server is self-contained and doesn't touch the module global.
|
||||
channels: dict = {}
|
||||
|
||||
def emit_frame(i: int) -> None:
|
||||
"""Log every channel for frame ``i`` stamped at its dataset timestamp."""
|
||||
@@ -518,21 +545,32 @@ def serve_foxglove_dataset_playback(
|
||||
arr = arr.numpy() if hasattr(arr, "numpy") else np.asarray(arr)
|
||||
if np.issubdtype(arr.dtype, np.floating):
|
||||
arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
|
||||
_log_foxglove_image(topic_for(key), key, arr, compress_images=compress_images, time_ns=log_time)
|
||||
_log_foxglove_image(
|
||||
_foxglove_topic(key, is_image=True),
|
||||
key,
|
||||
arr,
|
||||
compress_images=compress_images,
|
||||
channels=channels,
|
||||
log_time=log_time,
|
||||
)
|
||||
_log_foxglove_scalars(
|
||||
f"/{OBS_STR}/state",
|
||||
_foxglove_topic(OBS_STATE),
|
||||
_frame_to_scalars(sample, OBS_STATE, scalar_labels[OBS_STATE]),
|
||||
channels=channels,
|
||||
log_time=log_time,
|
||||
)
|
||||
_log_foxglove_scalars(
|
||||
f"/{ACTION}/state", _frame_to_scalars(sample, ACTION, scalar_labels[ACTION]), log_time=log_time
|
||||
_foxglove_topic(ACTION),
|
||||
_frame_to_scalars(sample, ACTION, scalar_labels[ACTION]),
|
||||
channels=channels,
|
||||
log_time=log_time,
|
||||
)
|
||||
episode_scalars = {}
|
||||
for feat, label in ((DONE, "done"), (REWARD, "reward"), (_SUCCESS, "success")):
|
||||
for feat, label in ((DONE, "done"), (REWARD, "reward"), (SUCCESS, "success")):
|
||||
v = sample.get(feat)
|
||||
if v is not None:
|
||||
episode_scalars[label] = float(v)
|
||||
_log_foxglove_scalars("/episode/state", episode_scalars, log_time=log_time)
|
||||
_log_foxglove_scalars("/episode/state", episode_scalars, channels=channels, log_time=log_time)
|
||||
|
||||
lock = threading.Lock()
|
||||
stop_event = threading.Event()
|
||||
@@ -644,15 +682,13 @@ def serve_foxglove_dataset_playback(
|
||||
stop_event.set()
|
||||
thread.join(timeout=2.0)
|
||||
server.stop()
|
||||
_foxglove_channels.clear()
|
||||
channels.clear()
|
||||
|
||||
|
||||
# ── Backend-agnostic dispatch ─────────────────────────────────────────────
|
||||
# These let callers select a visualization backend at runtime via a string
|
||||
# (e.g. a `--display_mode` CLI flag) without branching on the backend everywhere.
|
||||
|
||||
VISUALIZATION_MODES = ("rerun", "foxglove")
|
||||
|
||||
|
||||
def init_visualization(
|
||||
display_mode: str,
|
||||
@@ -664,13 +700,14 @@ def init_visualization(
|
||||
"""Initializes the visualization backend selected by ``display_mode``.
|
||||
|
||||
For ``"rerun"``, ``ip``/``port`` point at an optional remote Rerun server. For ``"foxglove"``,
|
||||
``port`` is the local WebSocket server port (``ip`` is ignored; the server binds locally).
|
||||
``ip`` is the interface to bind the WebSocket server to (``127.0.0.1`` for local only, ``0.0.0.0``
|
||||
for all interfaces) and ``port`` is its port.
|
||||
"""
|
||||
|
||||
if display_mode == "rerun":
|
||||
init_rerun(session_name=session_name, ip=ip, port=port)
|
||||
elif display_mode == "foxglove":
|
||||
init_foxglove(port=port)
|
||||
init_foxglove(host=ip or "127.0.0.1", port=port)
|
||||
else:
|
||||
raise ValueError(f"Unknown display_mode '{display_mode}'. Expected one of {VISUALIZATION_MODES}.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user