diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 535b84e77..66ac5a511 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -92,8 +92,7 @@ def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]: def get_sequential_colors(num_dims: int) -> list[list[int]]: - """Return a deterministic list of distinct RGB colors, one per dimension. - """ + """Return a deterministic list of distinct RGB colors, one per dimension.""" colors = [] for d in range(num_dims): hue = d / max(num_dims, 1) @@ -111,6 +110,41 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: return hwc_uint8_numpy +def build_blueprint_from_dataset(dataset: LeRobotDataset): + """Build a Rerun blueprint laying out camera images and time series for the given dataset. + + Camera images are arranged in a grid on the left, and the available scalar signals + (action, state, reward, done, success) are stacked as time series views on the right. + The per-dimension series names and colors for ``action`` and ``state`` are applied + directly via blueprint overrides. + """ + import rerun as rr + import rerun.blueprint as rrb + + image_views = [rrb.Spatial2DView(origin=key, name=key) for key in dataset.meta.camera_keys] + + timeseries_views = [] + # Style multi-dimensional signals (action, state) with per-dimension names and colors. + for origin, key in ((ACTION, ACTION), ("state", OBS_STATE)): + if key in dataset.features: + names = get_feature_names(dataset, key) + styling = rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))) + timeseries_views.append( + rrb.TimeSeriesView(origin=origin, name=origin, overrides={origin: styling}) + ) + for key in (DONE, REWARD, "next.success"): + if key in dataset.features: + timeseries_views.append(rrb.TimeSeriesView(origin=key, name=key)) + + contents = [] + if image_views: + contents.append(rrb.Grid(*image_views, name="images")) + if timeseries_views: + contents.append(rrb.Vertical(*timeseries_views, name="time series")) + + return rrb.Blueprint(rrb.Horizontal(*contents) if contents else rrb.Grid()) + + def visualize_dataset( dataset: LeRobotDataset, episode_index: int, @@ -149,7 +183,8 @@ def visualize_dataset( import rerun as rr spawn_local_viewer = mode == "local" and not save - rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) + blueprint = build_blueprint_from_dataset(dataset) + rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer, default_blueprint=blueprint) # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush # when iterating on a dataloader with `num_workers` > 0 @@ -163,36 +198,23 @@ def visualize_dataset( logging.info("Logging to Rerun") - # Name each series once (static) so all dimensions share a single view while keeping labels. - # Labels come straight from the dataset metadata. - if ACTION in dataset.features: - names = get_feature_names(dataset, ACTION) - rr.log(ACTION, rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True) - if OBS_STATE in dataset.features: - names = get_feature_names(dataset, OBS_STATE) - rr.log("state", rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True) - first_index = None for batch in tqdm.tqdm(dataloader, total=len(dataloader)): if first_index is None: first_index = batch["index"][0].item() - # iterate over the batch for i in range(len(batch["index"])): rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index) rr.set_time("timestamp", timestamp=batch["timestamp"][i].item()) - # display each camera image for key in dataset.meta.camera_keys: img = to_hwc_uint8_numpy(batch[key][i]) img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img) rr.log(key, entity=img_entity) - # display each dimension of action space (e.g. actuators command) if ACTION in batch: rr.log(ACTION, rr.Scalars(batch[ACTION][i].numpy())) - # display each dimension of observed state space (e.g. agent position in joint space) if OBS_STATE in batch: rr.log("state", rr.Scalars(batch[OBS_STATE][i].numpy())) @@ -206,8 +228,6 @@ def visualize_dataset( rr.log("next.success", rr.Scalars(batch["next.success"][i].item())) if mode == "local" and save: - # save .rrd locally - output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) repo_id_str = repo_id.replace("/", "_") rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd" @@ -215,7 +235,7 @@ def visualize_dataset( return rrd_path elif mode == "distant": - # stop the process from exiting since it is serving the websocket connection + # Keep the process alive while it serves the gRPC/web connection. try: while True: time.sleep(1) @@ -335,7 +355,7 @@ def main(): logging.info("Loading dataset") dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s) - visualize_dataset(dataset, **vars(args)) + visualize_dataset(dataset, **kwargs) if __name__ == "__main__":