From fcd8ab580068e30aed8eccb4101877899ba6d467 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Wed, 10 Jun 2026 20:25:12 +0200 Subject: [PATCH] fix(claude): claude reviews --- src/lerobot/scripts/lerobot_dataset_viz.py | 10 +++++++--- src/lerobot/utils/visualization_utils.py | 4 +++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index fc4777aa3..b0c93bb18 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -80,14 +80,16 @@ from lerobot.utils.utils import init_logging def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]: """Return per-dimension names for a feature from the dataset metadata. - Falls back to ``{key}_{i}`` indices when the metadata has no names. + Only flat-list ``names`` metadata is used. Dict-style ``names`` and missing names fall back to ``{key}_{i}`` indices. """ feature = dataset.features[key] + dim = feature["shape"][-1] + names = feature.get("names") - if names is not None: + if isinstance(names, list) and len(names) == dim: return [str(name) for name in names] - return [f"{key}_{d}" for d in range(feature["shape"][-1])] + return [f"{key}_{d}" for d in range(dim)] def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: @@ -329,6 +331,8 @@ def main(): ) logging.warning("Setting grpc_port to ws_port value.") kwargs["grpc_port"] = kwargs.pop("ws_port") + else: + kwargs.pop("ws_port") # Always remove ws_port from kwargs init_logging() logging.info("Loading dataset") diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index f48882de1..701daaac4 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -38,6 +38,8 @@ def init_rerun( require_package("rerun-sdk", extra="viz", import_name="rerun") import rerun as rr + log_rerun_data.blueprint = None # Reset blueprint cache for new session + batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size rr.init(session_name) @@ -110,7 +112,7 @@ def log_rerun_data( from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`. - 1D NumPy arrays are logged as a single `rr.Scalars` batch under one entity path, so that every dimension shares the same view instead of being split across one view per element. - - Other multi-dimensional arrays are flattened and logged as a single `rr.Scalars` batch. + - Multi-dimensional **action** arrays are flattened and logged as a single `rr.Scalars` batch. Keys are automatically namespaced with "observation." or "action." if not already present.