mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
feat(blueprints): switching to blueprints for backwards (and forward) compatibiltiy
This commit is contained in:
@@ -92,8 +92,7 @@ def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]:
|
||||
|
||||
|
||||
def get_sequential_colors(num_dims: int) -> list[list[int]]:
|
||||
"""Return a deterministic list of distinct RGB colors, one per dimension.
|
||||
"""
|
||||
"""Return a deterministic list of distinct RGB colors, one per dimension."""
|
||||
colors = []
|
||||
for d in range(num_dims):
|
||||
hue = d / max(num_dims, 1)
|
||||
@@ -111,6 +110,41 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
return hwc_uint8_numpy
|
||||
|
||||
|
||||
def build_blueprint_from_dataset(dataset: LeRobotDataset):
|
||||
"""Build a Rerun blueprint laying out camera images and time series for the given dataset.
|
||||
|
||||
Camera images are arranged in a grid on the left, and the available scalar signals
|
||||
(action, state, reward, done, success) are stacked as time series views on the right.
|
||||
The per-dimension series names and colors for ``action`` and ``state`` are applied
|
||||
directly via blueprint overrides.
|
||||
"""
|
||||
import rerun as rr
|
||||
import rerun.blueprint as rrb
|
||||
|
||||
image_views = [rrb.Spatial2DView(origin=key, name=key) for key in dataset.meta.camera_keys]
|
||||
|
||||
timeseries_views = []
|
||||
# Style multi-dimensional signals (action, state) with per-dimension names and colors.
|
||||
for origin, key in ((ACTION, ACTION), ("state", OBS_STATE)):
|
||||
if key in dataset.features:
|
||||
names = get_feature_names(dataset, key)
|
||||
styling = rr.SeriesLines(names=names, colors=get_sequential_colors(len(names)))
|
||||
timeseries_views.append(
|
||||
rrb.TimeSeriesView(origin=origin, name=origin, overrides={origin: styling})
|
||||
)
|
||||
for key in (DONE, REWARD, "next.success"):
|
||||
if key in dataset.features:
|
||||
timeseries_views.append(rrb.TimeSeriesView(origin=key, name=key))
|
||||
|
||||
contents = []
|
||||
if image_views:
|
||||
contents.append(rrb.Grid(*image_views, name="images"))
|
||||
if timeseries_views:
|
||||
contents.append(rrb.Vertical(*timeseries_views, name="time series"))
|
||||
|
||||
return rrb.Blueprint(rrb.Horizontal(*contents) if contents else rrb.Grid())
|
||||
|
||||
|
||||
def visualize_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
@@ -149,7 +183,8 @@ def visualize_dataset(
|
||||
import rerun as rr
|
||||
|
||||
spawn_local_viewer = mode == "local" and not save
|
||||
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
|
||||
blueprint = build_blueprint_from_dataset(dataset)
|
||||
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer, default_blueprint=blueprint)
|
||||
|
||||
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
|
||||
# when iterating on a dataloader with `num_workers` > 0
|
||||
@@ -163,36 +198,23 @@ def visualize_dataset(
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
# Name each series once (static) so all dimensions share a single view while keeping labels.
|
||||
# Labels come straight from the dataset metadata.
|
||||
if ACTION in dataset.features:
|
||||
names = get_feature_names(dataset, ACTION)
|
||||
rr.log(ACTION, rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
|
||||
if OBS_STATE in dataset.features:
|
||||
names = get_feature_names(dataset, OBS_STATE)
|
||||
rr.log("state", rr.SeriesLines(names=names, colors=get_sequential_colors(len(names))), static=True)
|
||||
|
||||
first_index = None
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
if first_index is None:
|
||||
first_index = batch["index"][0].item()
|
||||
|
||||
# iterate over the batch
|
||||
for i in range(len(batch["index"])):
|
||||
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
|
||||
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())
|
||||
|
||||
# display each camera image
|
||||
for key in dataset.meta.camera_keys:
|
||||
img = to_hwc_uint8_numpy(batch[key][i])
|
||||
img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img)
|
||||
rr.log(key, entity=img_entity)
|
||||
|
||||
# display each dimension of action space (e.g. actuators command)
|
||||
if ACTION in batch:
|
||||
rr.log(ACTION, rr.Scalars(batch[ACTION][i].numpy()))
|
||||
|
||||
# display each dimension of observed state space (e.g. agent position in joint space)
|
||||
if OBS_STATE in batch:
|
||||
rr.log("state", rr.Scalars(batch[OBS_STATE][i].numpy()))
|
||||
|
||||
@@ -206,8 +228,6 @@ def visualize_dataset(
|
||||
rr.log("next.success", rr.Scalars(batch["next.success"][i].item()))
|
||||
|
||||
if mode == "local" and save:
|
||||
# save .rrd locally
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
repo_id_str = repo_id.replace("/", "_")
|
||||
rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
|
||||
@@ -215,7 +235,7 @@ def visualize_dataset(
|
||||
return rrd_path
|
||||
|
||||
elif mode == "distant":
|
||||
# stop the process from exiting since it is serving the websocket connection
|
||||
# Keep the process alive while it serves the gRPC/web connection.
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
@@ -335,7 +355,7 @@ def main():
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
|
||||
|
||||
visualize_dataset(dataset, **vars(args))
|
||||
visualize_dataset(dataset, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user