feat(lerobot-dataset-viz): adding support for depth in lerobot-dataset-viz

This commit is contained in:
CarolinePascal
2026-06-24 17:45:54 +02:00
parent 1ba225c4ca
commit 73d15e160b
+36 -7
View File
@@ -77,15 +77,27 @@ from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.utils import init_logging
def check_chw_float32(frame: torch.Tensor) -> None:
    """Validate that ``frame`` is a channel-first, float32, 3-dimensional tensor.

    Args:
        frame: Tensor to validate.

    Raises:
        AssertionError: If the dtype is not ``torch.float32``, if the tensor is
            not 3-dimensional, or if the first dimension is not strictly
            smaller than both of the remaining dimensions.
    """
    assert frame.dtype == torch.float32, f"expect float32 images, but instead {frame.dtype}"
    assert frame.ndim == 3, f"expect 3-dimensional images, but instead {frame.ndim} dimensions"
    c, h, w = frame.shape
    # Heuristic layout check: the leading (channel) dimension must be smaller
    # than both trailing dimensions, otherwise the frame is rejected.
    assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    """Convert a channel-first float32 frame into a channel-last uint8 array.

    The input is multiplied by 255 before the cast, so it is presumably
    expected to hold values in [0, 1] (confirm against the dataset loader).
    Scaled values outside [0, 255] saturate instead of wrapping around.

    Args:
        chw_float32_torch: Frame accepted by ``check_chw_float32``.

    Returns:
        The frame as a uint8 NumPy array with its first axis moved last.

    Raises:
        AssertionError: If ``check_chw_float32`` rejects the frame.
    """
    # Validation is delegated to the shared helper; repeating the same asserts
    # inline would only duplicate the checks.
    check_chw_float32(chw_float32_torch)
    # Clamp before casting: converting an out-of-range float to an unsigned
    # integer type is not well defined and may wrap.
    scaled = torch.clamp(chw_float32_torch * 255, 0, 255)
    hwc_uint8_numpy = scaled.type(torch.uint8).permute(1, 2, 0).numpy()
    return hwc_uint8_numpy
def to_hwc_uint16_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    """Convert a channel-first float32 frame into a channel-last uint16 array.

    Values are rounded to the nearest integer (no rescaling is applied) and
    saturated to the uint16 range before the cast.

    Args:
        chw_float32_torch: Frame accepted by ``check_chw_float32``.

    Returns:
        The frame as a uint16 NumPy array with its first axis moved last.

    Raises:
        AssertionError: If ``check_chw_float32`` rejects the frame.
    """
    check_chw_float32(chw_float32_torch)
    uint16_max = 65535
    # Clamp before casting: negative or too-large values would otherwise wrap
    # around when converted to an unsigned integer type.
    rounded = torch.clamp(chw_float32_torch.round(), 0, uint16_max)
    hwc_uint16_numpy = rounded.type(torch.uint16).permute(1, 2, 0).numpy()
    return hwc_uint16_numpy
def visualize_dataset(
dataset: LeRobotDataset,
episode_index: int,
@@ -138,6 +150,14 @@ def visualize_dataset(
logging.info("Logging to Rerun")
# Use the dataset's q01/q99 depth statistics for robust depth range bounds
depth_ranges = {}
for key in dataset.meta.depth_keys:
stats = dataset.meta.stats[key]
lo = stats["q01"] if "q01" in stats else stats["min"]
hi = stats["q99"] if "q99" in stats else stats["max"]
depth_ranges[key] = (float(lo), float(hi))
first_index = None
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
if first_index is None:
@@ -149,9 +169,18 @@ def visualize_dataset(
# display each camera image
for key in dataset.meta.camera_keys:
img = to_hwc_uint8_numpy(batch[key][i])
img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img)
rr.log(key, entity=img_entity)
if key in dataset.meta.depth_keys:
depth = to_hwc_uint16_numpy(batch[key][i])
depth_entity = rr.DepthImage(
depth,
colormap=rr.components.Colormap.Viridis,
depth_range=depth_ranges[key],
)
rr.log(key, entity=depth_entity)
else:
img = to_hwc_uint8_numpy(batch[key][i])
img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img)
rr.log(key, entity=img_entity)
# display each dimension of action space (e.g. actuators command)
if ACTION in batch: