mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
167 lines
6.2 KiB
Python
167 lines
6.2 KiB
Python
import torch
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
import rerun as rr
|
|
import rerun.blueprint as rrb
|
|
|
|
|
|
class RerunLogger:
|
|
"""
|
|
A fully automatic Rerun logger designed to parse and visualize step
|
|
dictionaries directly from a LeRobotDataset.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
prefix: str = "",
|
|
memory_limit: str = "200MB",
|
|
idxrangeboundary: int | None = 300,
|
|
):
|
|
"""Initializes the Rerun logger."""
|
|
# Use a descriptive name for the Rerun recording
|
|
rr.init(f"Dataset_Log_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
|
rr.spawn(memory_limit=memory_limit)
|
|
|
|
self.prefix = prefix
|
|
self.blueprint_sent = False
|
|
self.idxrangeboundary = idxrangeboundary
|
|
|
|
# --- Internal cache for discovered keys ---
|
|
self._image_keys: tuple[str, ...] = ()
|
|
self._state_key: str = ""
|
|
self._action_key: str = ""
|
|
self._index_key: str = "index"
|
|
self._task_key: str = "task"
|
|
self._episode_index_key: str = "episode_index"
|
|
|
|
self.current_episode = -1
|
|
|
|
def _initialize_from_data(self, step_data: dict[str, Any]):
|
|
"""Inspects the first data dictionary to discover components and set up the blueprint."""
|
|
print("RerunLogger: First data packet received. Auto-configuring...")
|
|
|
|
image_keys = []
|
|
for key, value in step_data.items():
|
|
if key.startswith("observation.images.") and isinstance(value, torch.Tensor) and value.ndim > 2:
|
|
image_keys.append(key)
|
|
elif key == "observation.state":
|
|
self._state_key = key
|
|
elif key == "action":
|
|
self._action_key = key
|
|
|
|
self._image_keys = tuple(sorted(image_keys))
|
|
|
|
if "index" in step_data:
|
|
self._index_key = "index"
|
|
elif "frame_index" in step_data:
|
|
self._index_key = "frame_index"
|
|
|
|
print(f" - Using '{self._index_key}' for time sequence.")
|
|
print(f" - Detected State Key: '{self._state_key}'")
|
|
print(f" - Detected Action Key: '{self._action_key}'")
|
|
print(f" - Detected Image Keys: {self._image_keys}")
|
|
if self.idxrangeboundary:
|
|
self.setup_blueprint()
|
|
|
|
def setup_blueprint(self):
|
|
"""Sets up and sends the Rerun blueprint based on detected components."""
|
|
views = []
|
|
|
|
for key in self._image_keys:
|
|
clean_name = key.replace("observation.images.", "")
|
|
entity_path = f"{self.prefix}images/{clean_name}"
|
|
views.append(rrb.Spatial2DView(origin=entity_path, name=clean_name))
|
|
|
|
if self._state_key:
|
|
entity_path = f"{self.prefix}state"
|
|
views.append(
|
|
rrb.TimeSeriesView(
|
|
origin=entity_path,
|
|
name="Observation State",
|
|
time_ranges=[
|
|
rrb.VisibleTimeRange(
|
|
"frame",
|
|
start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary),
|
|
end=rrb.TimeRangeBoundary.cursor_relative(),
|
|
)
|
|
],
|
|
plot_legend=rrb.PlotLegend(visible=True),
|
|
)
|
|
)
|
|
|
|
if self._action_key:
|
|
entity_path = f"{self.prefix}action"
|
|
views.append(
|
|
rrb.TimeSeriesView(
|
|
origin=entity_path,
|
|
name="Action",
|
|
time_ranges=[
|
|
rrb.VisibleTimeRange(
|
|
"frame",
|
|
start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary),
|
|
end=rrb.TimeRangeBoundary.cursor_relative(),
|
|
)
|
|
],
|
|
plot_legend=rrb.PlotLegend(visible=True),
|
|
)
|
|
)
|
|
|
|
if not views:
|
|
print("Warning: No visualizable components detected in the data.")
|
|
return
|
|
|
|
grid = rrb.Grid(contents=views)
|
|
rr.send_blueprint(grid)
|
|
self.blueprint_sent = True
|
|
|
|
def log_step(self, step_data: dict[str, Any]):
|
|
"""Logs a single step dictionary from your dataset."""
|
|
if not self.blueprint_sent:
|
|
self._initialize_from_data(step_data)
|
|
|
|
if self._index_key in step_data:
|
|
current_index = step_data[self._index_key].item()
|
|
rr.set_time_sequence("frame", current_index)
|
|
|
|
episode_idx = step_data.get(self._episode_index_key, torch.tensor(-1)).item()
|
|
if episode_idx != self.current_episode:
|
|
self.current_episode = episode_idx
|
|
task_name = step_data.get(self._task_key, "Unknown Task")
|
|
log_text = f"Starting Episode {self.current_episode}: {task_name}"
|
|
rr.log(f"{self.prefix}info/task", rr.TextLog(log_text, level=rr.TextLogLevel.INFO))
|
|
|
|
for key in self._image_keys:
|
|
if key in step_data:
|
|
image_tensor = step_data[key]
|
|
if image_tensor.ndim > 2:
|
|
clean_name = key.replace("observation.images.", "")
|
|
entity_path = f"{self.prefix}images/{clean_name}"
|
|
if image_tensor.shape[0] in [1, 3, 4]:
|
|
image_tensor = image_tensor.permute(1, 2, 0)
|
|
rr.log(entity_path, rr.Image(image_tensor))
|
|
|
|
if self._state_key in step_data:
|
|
state_tensor = step_data[self._state_key]
|
|
entity_path = f"{self.prefix}state"
|
|
for i, val in enumerate(state_tensor):
|
|
rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item()))
|
|
|
|
if self._action_key in step_data:
|
|
action_tensor = step_data[self._action_key]
|
|
entity_path = f"{self.prefix}action"
|
|
for i, val in enumerate(action_tensor):
|
|
rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item()))
|
|
|
|
|
|
def visualization_data(idx, observation, state, action, online_logger):
    """Assemble one step dictionary and hand it to ``online_logger.log_step``.

    The step always carries ``"index"`` (``idx`` wrapped in a tensor),
    ``"observation.state"`` and ``"action"``. Every other entry of
    ``observation`` is copied over; entries of ``observation`` that use one
    of those three names are ignored so they cannot override the explicit
    arguments.
    """
    reserved: dict[str, Any] = {
        "index": torch.tensor(idx),
        "observation.state": state,
        "action": action,
    }
    extras = {name: value for name, value in observation.items() if name not in reserved}
    # Reserved entries first, then the remaining observation entries in order.
    online_logger.log_step({**reserved, **extras})
|