feat(blueprints): switching to blueprints for backwards (and forward) compatibiltiy

This commit is contained in:
CarolinePascal
2026-06-10 18:51:03 +02:00
parent 2c47217825
commit dabf88ef9f
+40 -20
View File
@@ -92,8 +92,7 @@ def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]:
def get_sequential_colors(num_dims: int) -> list[list[int]]:
"""Return a deterministic list of distinct RGB colors, one per dimension.
"""
"""Return a deterministic list of distinct RGB colors, one per dimension."""
colors = []
for d in range(num_dims):
hue = d / max(num_dims, 1)
@@ -111,6 +110,41 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
return hwc_uint8_numpy
def build_blueprint_from_dataset(dataset: LeRobotDataset):
"""Build a Rerun blueprint laying out camera images and time series for the given dataset.
Camera images are arranged in a grid on the left, and the available scalar signals
(action, state, reward, done, success) are stacked as time series views on the right.
The per-dimension series names and colors for ``action`` and ``state`` are applied
directly via blueprint overrides.
"""
import rerun as rr
import rerun.blueprint as rrb
image_views = [rrb.Spatial2DView(origin=key, name=key) for key in dataset.meta.camera_keys]
timeseries_views = []
# Style multi-dimensional signals (action, state) with per-dimension names and colors.
for origin, key in ((ACTION, ACTION), ("state", OBS_STATE)):
if key in dataset.features:
names = get_feature_names(dataset, key)
styling = rr.SeriesLines(names=names, colors=get_sequential_colors(len(names)))
timeseries_views.append(
rrb.TimeSeriesView(origin=origin, name=origin, overrides={origin: styling})
)
for key in (DONE, REWARD, "next.success"):
if key in dataset.features:
timeseries_views.append(rrb.TimeSeriesView(origin=key, name=key))
contents = []
if image_views:
contents.append(rrb.Grid(*image_views, name="images"))
if timeseries_views:
contents.append(rrb.Vertical(*timeseries_views, name="time series"))
return rrb.Blueprint(rrb.Horizontal(*contents) if contents else rrb.Grid())
def visualize_dataset(
dataset: LeRobotDataset,
episode_index: int,
@@ -149,7 +183,8 @@ def visualize_dataset(
import rerun as rr
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
blueprint = build_blueprint_from_dataset(dataset)
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer, default_blueprint=blueprint)
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
# when iterating on a dataloader with `num_workers` > 0
@@ -163,36 +198,23 @@ def visualize_dataset(
logging.info("Logging to Rerun")
# Name each series once (static) so all dimensions share a single view while keeping labels.
# Labels come straight from the dataset metadata.
if ACTION in dataset.features:
names = get_feature_names(dataset, ACTION)
rr.log(ACTION, rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
if OBS_STATE in dataset.features:
names = get_feature_names(dataset, OBS_STATE)
rr.log("state", rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
first_index = None
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
if first_index is None:
first_index = batch["index"][0].item()
# iterate over the batch
for i in range(len(batch["index"])):
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())
# display each camera image
for key in dataset.meta.camera_keys:
img = to_hwc_uint8_numpy(batch[key][i])
img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img)
rr.log(key, entity=img_entity)
# display each dimension of action space (e.g. actuators command)
if ACTION in batch:
rr.log(ACTION, rr.Scalars(batch[ACTION][i].numpy()))
# display each dimension of observed state space (e.g. agent position in joint space)
if OBS_STATE in batch:
rr.log("state", rr.Scalars(batch[OBS_STATE][i].numpy()))
@@ -206,8 +228,6 @@ def visualize_dataset(
rr.log("next.success", rr.Scalars(batch["next.success"][i].item()))
if mode == "local" and save:
# save .rrd locally
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
repo_id_str = repo_id.replace("/", "_")
rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
@@ -215,7 +235,7 @@ def visualize_dataset(
return rrd_path
elif mode == "distant":
# stop the process from exiting since it is serving the websocket connection
# Keep the process alive while it serves the gRPC/web connection.
try:
while True:
time.sleep(1)
@@ -335,7 +355,7 @@ def main():
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
visualize_dataset(dataset, **vars(args))
visualize_dataset(dataset, **kwargs)
if __name__ == "__main__":