mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-25 12:17:08 +00:00
feat(lerobot-dataset-viz): adding support for depth in lerobot-dataset-viz
This commit is contained in:
@@ -77,15 +77,27 @@ from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def check_chw_float32(frame: torch.Tensor) -> None:
    """
    Check that a frame is a channel-first (C, H, W), float32 tensor.

    Args:
        frame: Tensor to validate.

    Raises:
        AssertionError: If the dtype is not ``torch.float32``, if the tensor does not have
            exactly three dimensions, or if the first dimension is not strictly smaller than
            both of the other two.
    """
    # Raise explicitly rather than via `assert` so the checks are not stripped under
    # `python -O`. AssertionError is kept so callers observe the same exception type.
    if frame.dtype != torch.float32:
        raise AssertionError(f"expect float32 images, but instead {frame.dtype}")
    if frame.ndim != 3:
        raise AssertionError(f"expect 3 dimensions (c, h, w), but instead {frame.ndim}")

    c, h, w = frame.shape
    # Heuristic: the channel count is expected to be smaller than both spatial sizes.
    if not (c < h and c < w):
        raise AssertionError(f"expect channel first images, but instead {frame.shape}")
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    """
    Convert a channel-first float32 image tensor to a channel-last uint8 numpy array.

    Args:
        chw_float32_torch: Image tensor of shape (C, H, W) and dtype float32. Values are
            multiplied by 255, so they are presumably normalized to [0, 1] (confirm against
            the dataset loader).

    Returns:
        Array of shape (H, W, C) and dtype uint8.

    Raises:
        AssertionError: If `check_chw_float32` rejects the input.
    """
    # Input validation is delegated to the shared helper instead of being repeated inline.
    check_chw_float32(chw_float32_torch)
    # Clamp before the cast: converting out-of-range floats to uint8 is not well defined
    # and would otherwise show up as wrapped, misleading pixel values.
    scaled = (chw_float32_torch * 255).clamp(0, 255)
    hwc_uint8_numpy = scaled.type(torch.uint8).permute(1, 2, 0).numpy()
    return hwc_uint8_numpy
|
||||
|
||||
|
||||
def to_hwc_uint16_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    """
    Convert a channel-first float32 tensor to a channel-last uint16 numpy array.

    Values are rounded to the nearest integer (no rescaling is applied) and then limited to
    the range representable by uint16.

    Args:
        chw_float32_torch: Tensor of shape (C, H, W) and dtype float32.

    Returns:
        Array of shape (H, W, C) and dtype uint16.

    Raises:
        AssertionError: If `check_chw_float32` rejects the input.
    """
    check_chw_float32(chw_float32_torch)
    # Clamp after rounding: negative values or values above 65535 cannot be represented
    # in uint16 and would otherwise wrap around when cast.
    rounded = chw_float32_torch.round().clamp(0, 65535)
    hwc_uint16_numpy = rounded.type(torch.uint16).permute(1, 2, 0).numpy()
    return hwc_uint16_numpy
|
||||
|
||||
|
||||
def visualize_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
@@ -138,6 +150,14 @@ def visualize_dataset(
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
# Use the dataset's q01/q99 depth statistics for robust depth range bounds
|
||||
depth_ranges = {}
|
||||
for key in dataset.meta.depth_keys:
|
||||
stats = dataset.meta.stats[key]
|
||||
lo = stats["q01"] if "q01" in stats else stats["min"]
|
||||
hi = stats["q99"] if "q99" in stats else stats["max"]
|
||||
depth_ranges[key] = (float(lo), float(hi))
|
||||
|
||||
first_index = None
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
if first_index is None:
|
||||
@@ -149,9 +169,18 @@ def visualize_dataset(
|
||||
|
||||
# display each camera image
|
||||
for key in dataset.meta.camera_keys:
|
||||
img = to_hwc_uint8_numpy(batch[key][i])
|
||||
img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img)
|
||||
rr.log(key, entity=img_entity)
|
||||
if key in dataset.meta.depth_keys:
|
||||
depth = to_hwc_uint16_numpy(batch[key][i])
|
||||
depth_entity = rr.DepthImage(
|
||||
depth,
|
||||
colormap=rr.components.Colormap.Viridis,
|
||||
depth_range=depth_ranges[key],
|
||||
)
|
||||
rr.log(key, entity=depth_entity)
|
||||
else:
|
||||
img = to_hwc_uint8_numpy(batch[key][i])
|
||||
img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img)
|
||||
rr.log(key, entity=img_entity)
|
||||
|
||||
# display each dimension of action space (e.g. actuators command)
|
||||
if ACTION in batch:
|
||||
|
||||
Reference in New Issue
Block a user