From 2c472178256ec40566b837696d62a3a4509757d5 Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Wed, 10 Jun 2026 18:07:32 +0200
Subject: [PATCH] feat(features names and color): improving features names and
 display colors when replaying an episode

---
Review note: get_feature_names now flattens dict/nested `names` metadata and
falls back to indexed names when the count does not match the feature shape.
The post-image blob id on the `index` line below predates this revision.

 src/lerobot/scripts/lerobot_dataset_viz.py | 63 ++++++++++++++++------
 1 file changed, 46 insertions(+), 17 deletions(-)

diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py
index a7c0b774e..535b84e77 100644
--- a/src/lerobot/scripts/lerobot_dataset_viz.py
+++ b/src/lerobot/scripts/lerobot_dataset_viz.py
@@ -62,6 +62,7 @@ local$ rerun rerun+http://IP:GRPC_PORT/proxy
 """
 
 import argparse
+import colorsys
 import gc
 import logging
 import time
@@ -77,6 +78,36 @@ from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
 from lerobot.utils.utils import init_logging
 
 
+def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]:
+    """Return per-dimension names for a feature from the dataset metadata.
+
+    Names stored as a dict of groups (e.g. ``{"motors": [...]}``) or as nested lists are
+    flattened in order. Falls back to ``{key}_{i}`` indices when the metadata has no names
+    or when the number of names does not match the last dimension of the feature shape.
+    """
+    feature = dataset.features[key]
+    num_dims = feature["shape"][-1]
+
+    def flatten(names) -> list[str]:
+        if isinstance(names, dict):
+            names = list(names.values())
+        if isinstance(names, (list, tuple)):
+            return [leaf for name in names for leaf in flatten(name)]
+        return [str(names)]
+
+    names = feature.get("names")
+    flat_names = flatten(names) if names is not None else []
+    if len(flat_names) == num_dims:
+        return flat_names
+
+    return [f"{key}_{d}" for d in range(num_dims)]
+
+
+def get_sequential_colors(num_dims: int) -> list[list[int]]:
+    """Return a deterministic list of distinct RGB colors, one per dimension."""
+    colors = []
+    for d in range(num_dims):
+        # Spread hues evenly around the color wheel at fixed saturation and value.
+        hue = d / num_dims
+        r, g, b = colorsys.hsv_to_rgb(hue, 0.7, 0.9)
+        colors.append([int(r * 255), int(g * 255), int(b * 255)])
+    return colors
+
+
 def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
     assert chw_float32_torch.dtype == torch.float32
     assert chw_float32_torch.ndim == 3
@@ -138,28 +175,20 @@ def visualize_dataset(
 
     logging.info("Logging to Rerun")
 
+    # Name each series once (static) so all dimensions share a single view while keeping labels.
+    # Labels come straight from the dataset metadata.
+    if ACTION in dataset.features:
+        names = get_feature_names(dataset, ACTION)
+        rr.log(ACTION, rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
+    if OBS_STATE in dataset.features:
+        names = get_feature_names(dataset, OBS_STATE)
+        rr.log("state", rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
+
     first_index = None
-    series_names_logged = False
     for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
         if first_index is None:
             first_index = batch["index"][0].item()
 
-        # Name each series once (static) so all dimensions share a single view while keeping labels.
-        if not series_names_logged:
-            if ACTION in batch:
-                rr.log(
-                    ACTION,
-                    rr.SeriesLines(names=[f"{ACTION}_{d}" for d in range(batch[ACTION].shape[-1])]),
-                    static=True,
-                )
-            if OBS_STATE in batch:
-                rr.log(
-                    "state",
-                    rr.SeriesLines(names=[f"state_{d}" for d in range(batch[OBS_STATE].shape[-1])]),
-                    static=True,
-                )
-            series_names_logged = True
-
         # iterate over the batch
         for i in range(len(batch["index"])):
             rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)