mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
feat(features names and color): improving features names and display colors when replaying an episode
This commit is contained in:
@@ -62,6 +62,7 @@ local$ rerun rerun+http://IP:GRPC_PORT/proxy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import colorsys
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
@@ -77,6 +78,30 @@ from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]:
|
||||
"""Return per-dimension names for a feature from the dataset metadata.
|
||||
|
||||
Falls back to ``{key}_{i}`` indices when the metadata has no names.
|
||||
"""
|
||||
feature = dataset.features[key]
|
||||
names = feature.get("names")
|
||||
if names is not None:
|
||||
return [str(name) for name in names]
|
||||
|
||||
return [f"{key}_{d}" for d in range(feature["shape"][-1])]
|
||||
|
||||
|
||||
def get_sequential_colors(num_dims: int) -> list[list[int]]:
|
||||
"""Return a deterministic list of distinct RGB colors, one per dimension.
|
||||
"""
|
||||
colors = []
|
||||
for d in range(num_dims):
|
||||
hue = d / max(num_dims, 1)
|
||||
r, g, b = colorsys.hsv_to_rgb(hue, 0.7, 0.9)
|
||||
colors.append([int(r * 255), int(g * 255), int(b * 255)])
|
||||
return colors
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
@@ -138,28 +163,20 @@ def visualize_dataset(
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
# Name each series once (static) so all dimensions share a single view while keeping labels.
|
||||
# Labels come straight from the dataset metadata.
|
||||
if ACTION in dataset.features:
|
||||
names = get_feature_names(dataset, ACTION)
|
||||
rr.log(ACTION, rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
|
||||
if OBS_STATE in dataset.features:
|
||||
names = get_feature_names(dataset, OBS_STATE)
|
||||
rr.log("state", rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
|
||||
|
||||
first_index = None
|
||||
series_names_logged = False
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
if first_index is None:
|
||||
first_index = batch["index"][0].item()
|
||||
|
||||
# Name each series once (static) so all dimensions share a single view while keeping labels.
|
||||
if not series_names_logged:
|
||||
if ACTION in batch:
|
||||
rr.log(
|
||||
ACTION,
|
||||
rr.SeriesLines(names=[f"{ACTION}_{d}" for d in range(batch[ACTION].shape[-1])]),
|
||||
static=True,
|
||||
)
|
||||
if OBS_STATE in batch:
|
||||
rr.log(
|
||||
"state",
|
||||
rr.SeriesLines(names=[f"state_{d}" for d in range(batch[OBS_STATE].shape[-1])]),
|
||||
static=True,
|
||||
)
|
||||
series_names_logged = True
|
||||
|
||||
# iterate over the batch
|
||||
for i in range(len(batch["index"])):
|
||||
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
|
||||
|
||||
Reference in New Issue
Block a user