feat(features names and color): improving features names and display colors when replaying an episode

This commit is contained in:
CarolinePascal
2026-06-10 18:07:32 +02:00
parent 9c502e204e
commit 2c47217825
+34 -17
View File
@@ -62,6 +62,7 @@ local$ rerun rerun+http://IP:GRPC_PORT/proxy
"""
import argparse
import colorsys
import gc
import logging
import time
@@ -77,6 +78,30 @@ from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.utils import init_logging
def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]:
"""Return per-dimension names for a feature from the dataset metadata.
Falls back to ``{key}_{i}`` indices when the metadata has no names.
"""
feature = dataset.features[key]
names = feature.get("names")
if names is not None:
return [str(name) for name in names]
return [f"{key}_{d}" for d in range(feature["shape"][-1])]
def get_sequential_colors(num_dims: int) -> list[list[int]]:
"""Return a deterministic list of distinct RGB colors, one per dimension.
"""
colors = []
for d in range(num_dims):
hue = d / max(num_dims, 1)
r, g, b = colorsys.hsv_to_rgb(hue, 0.7, 0.9)
colors.append([int(r * 255), int(g * 255), int(b * 255)])
return colors
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
@@ -138,28 +163,20 @@ def visualize_dataset(
logging.info("Logging to Rerun")
# Name each series once (static) so all dimensions share a single view while keeping labels.
# Labels come straight from the dataset metadata.
if ACTION in dataset.features:
names = get_feature_names(dataset, ACTION)
rr.log(ACTION, rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
if OBS_STATE in dataset.features:
names = get_feature_names(dataset, OBS_STATE)
rr.log("state", rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
first_index = None
series_names_logged = False
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
if first_index is None:
first_index = batch["index"][0].item()
# Name each series once (static) so all dimensions share a single view while keeping labels.
if not series_names_logged:
if ACTION in batch:
rr.log(
ACTION,
rr.SeriesLines(names=[f"{ACTION}_{d}" for d in range(batch[ACTION].shape[-1])]),
static=True,
)
if OBS_STATE in batch:
rr.log(
"state",
rr.SeriesLines(names=[f"state_{d}" for d in range(batch[OBS_STATE].shape[-1])]),
static=True,
)
series_names_logged = True
# iterate over the batch
for i in range(len(batch["index"])):
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)