From 73d15e160bb4f0b1b570b157edf4100ff6cb8e08 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Wed, 24 Jun 2026 17:45:54 +0200 Subject: [PATCH] feat(lerobot-dataset-viz): adding support for depth in lerobot-dataset-viz --- src/lerobot/scripts/lerobot_dataset_viz.py | 43 ++++++++++++++++++---- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index d07a2767d..43f992f46 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -77,15 +77,27 @@ from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD from lerobot.utils.utils import init_logging +def check_chw_float32(frame: torch.Tensor) -> None: + """ + Check if a frame is a channel-first, float32 tensor. + """ + assert frame.dtype == torch.float32 + assert frame.ndim == 3 + c, h, w = frame.shape + assert c < h and c < w, f"expect channel first images, but instead {frame.shape}" + def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: - assert chw_float32_torch.dtype == torch.float32 - assert chw_float32_torch.ndim == 3 - c, h, w = chw_float32_torch.shape - assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}" + check_chw_float32(chw_float32_torch) hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy() return hwc_uint8_numpy +def to_hwc_uint16_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: + check_chw_float32(chw_float32_torch) + hwc_uint16_numpy = chw_float32_torch.round().type(torch.uint16).permute(1, 2, 0).numpy() + return hwc_uint16_numpy + + def visualize_dataset( dataset: LeRobotDataset, episode_index: int, @@ -138,6 +150,14 @@ def visualize_dataset( logging.info("Logging to Rerun") + # Use the dataset's q01/q99 depth statistics for robust depth range bounds + depth_ranges = {} + for key in dataset.meta.depth_keys: + stats = dataset.meta.stats[key] + lo = stats["q01"] if "q01" in stats else stats["min"] + hi = stats["q99"] if "q99" in stats else stats["max"] + depth_ranges[key] = (float(lo), float(hi)) + first_index = None for batch in tqdm.tqdm(dataloader, total=len(dataloader)): if first_index is None: @@ -149,9 +169,18 @@ def visualize_dataset( # display each camera image for key in dataset.meta.camera_keys: - img = to_hwc_uint8_numpy(batch[key][i]) - img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img) - rr.log(key, entity=img_entity) + if key in dataset.meta.depth_keys: + depth = to_hwc_uint16_numpy(batch[key][i]) + depth_entity = rr.DepthImage( + depth, + colormap=rr.components.Colormap.Viridis, + depth_range=depth_ranges[key], + ) + rr.log(key, entity=depth_entity) + else: + img = to_hwc_uint8_numpy(batch[key][i]) + img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img) + rr.log(key, entity=img_entity) # display each dimension of action space (e.g. actuators command) if ACTION in batch: