fix(tests)

- Updated `lerobot_dataset.py:add_frame` to take task as key in frame
- Updated `lerobot_dataset.py` to remove robot argument from `create` function of lerobotdataset and lerobotdatasetmetadata and directly take the features
- Update `test_datasets.py` to features from Mock robot
- Update all the usage of `add_frame` in the library
- Update `dataset_factories.py`; had issues with new argument order
- Raise ValueError when no task is provided (in `datasets/utils.py` validate func)
This commit is contained in:
Michel Aractingi
2025-07-01 16:06:48 +02:00
parent 67485b1edc
commit e43ece3271
10 changed files with 189 additions and 61 deletions
+2 -2
View File
@@ -51,10 +51,10 @@ while i < NB_CYCLES_CLIENT_CONNECTION:
action_sent = robot.send_action(action)
observation = robot.get_observation()
frame = {**action_sent, **observation}
task = "Dummy Example Task Dataset"
frame = {**action_sent, **observation, "task": task}
dataset.add_frame(frame, task)
dataset.add_frame(frame)
i += 1
print("Disconnecting Teleop Devices and LeKiwi Client")
+9 -36
View File
@@ -40,6 +40,7 @@ from lerobot.common.datasets.utils import (
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
INFO_PATH,
_validate_feature_names,
check_delta_timestamps,
check_version_compatibility,
concat_video_files,
@@ -79,7 +80,7 @@ from lerobot.common.datasets.video_utils import (
get_safe_default_codec,
get_video_info,
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robots.utils import Robot
CODEBASE_VERSION = "v3.0"
@@ -375,10 +376,9 @@ class LeRobotDatasetMetadata:
cls,
repo_id: str,
fps: int,
root: str | Path | None = None,
robot: Robot | None = None,
features: dict,
robot_type: str | None = None,
features: dict | None = None,
root: str | Path | None = None,
use_videos: bool = True,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
@@ -388,34 +388,13 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False)
if robot is not None:
features = get_features_from_robot(robot, use_videos)
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
)
elif features is None:
raise ValueError(
"Dataset features must either come from a Robot or explicitly passed upon creation."
)
else:
# TODO(aliberts, rcadene): implement sanity check for features
features = {**features, **DEFAULT_FEATURES}
# check if none of the features contains a "/" in their names,
# as this would break the dict flattening in the stats computation, which uses '/' as separator
for key in features:
if "/" in key:
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
features = {**features, **DEFAULT_FEATURES}
_validate_feature_names(features)
obj.tasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
@@ -867,14 +846,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing
# Add frame features to episode_buffer
for key in frame:
if key == "task":
# Note: we associate the task in natural language to its task index during `save_episode`
self.episode_buffer["task"].append(frame["task"])
continue
if key not in self.features:
raise ValueError(
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
@@ -1132,10 +1107,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
cls,
repo_id: str,
fps: int,
features: dict,
root: str | Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
features: dict | None = None,
use_videos: bool = True,
tolerance_s: float = 1e-4,
image_writer_processes: int = 0,
@@ -1147,10 +1121,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=fps,
root=root,
robot=robot,
robot_type=robot_type,
features=features,
root=root,
use_videos=use_videos,
)
obj.repo_id = obj.meta.repo_id
+10 -3
View File
@@ -834,10 +834,17 @@ def validate_frame(frame: dict, features: dict):
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
error_message = validate_features_presence(actual_features, expected_features)
# task is a special required field that's not part of regular features
if "task" not in actual_features:
raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")
common_features = actual_features & expected_features
for name in common_features - {"task"}:
# Remove task from actual_features for regular feature validation
actual_features_for_validation = actual_features - {"task"}
error_message = validate_features_presence(actual_features_for_validation, expected_features)
common_features = actual_features_for_validation & expected_features
for name in common_features:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
+2 -1
View File
@@ -596,7 +596,8 @@ class ReplayBuffer:
frame_dict[f"complementary_info.{key}"] = val
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict, task=task_name)
frame_dict["task"] = task_name
lerobot_dataset.add_frame(frame_dict)
# Move to next frame
frame_idx_in_episode += 1
+148 -4
View File
@@ -17,10 +17,14 @@ import logging
import os
import os.path as osp
import platform
import select
import subprocess
from copy import copy
import sys
import time
from copy import copy, deepcopy
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean
import numpy as np
import torch
@@ -107,10 +111,16 @@ def is_amp_available(device: str):
raise ValueError(f"Unknown device '{device}.")
def init_logging():
def init_logging(log_file: Path | None = None, display_pid: bool = False):
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
# NOTE: Display PID is useful for multi-process logging.
if display_pid:
pid_str = f"[PID: {os.getpid()}]"
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
else:
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
return message
@@ -125,6 +135,12 @@ def init_logging():
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
if log_file is not None:
# Additionally write logs to file
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logging.getLogger().addHandler(file_handler)
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -168,7 +184,7 @@ def capture_timestamp_utc():
return datetime.now(timezone.utc)
def say(text, blocking=False):
def say(text: str, blocking: bool = False):
system = platform.system()
if system == "Darwin":
@@ -196,7 +212,7 @@ def say(text, blocking=False):
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
def log_say(text, play_sounds, blocking=False):
def log_say(text: str, play_sounds: bool = True, blocking: bool = False):
logging.info(text)
if play_sounds:
@@ -230,6 +246,23 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
return False
def enter_pressed() -> bool:
if platform.system() == "Windows":
import msvcrt
if msvcrt.kbhit():
key = msvcrt.getch()
return key in (b"\r", b"\n") # enter key
return False
else:
return select.select([sys.stdin], [], [], 0)[0] and sys.stdin.readline().strip() == ""
def move_cursor_up(lines):
"""Move the cursor up by a specified number of lines."""
print(f"\033[{lines}A", end="")
def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
days = int(elapsed_time_s // (24 * 3600))
elapsed_time_s %= 24 * 3600
@@ -238,3 +271,114 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
minutes = int(elapsed_time_s // 60)
seconds = elapsed_time_s % 60
return days, hours, minutes, seconds
class TimerManager:
"""
Lightweight utility to measure elapsed time.
Examples
--------
```python
# Example 1: Using context manager
timer = TimerManager("Policy", log=False)
for _ in range(3):
with timer:
time.sleep(0.01)
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
```
```python
# Example 2: Using start/stop methods
timer = TimerManager("Policy", log=False)
timer.start()
time.sleep(0.01)
timer.stop()
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
```
"""
def __init__(
self,
label: str = "Elapsed-time",
log: bool = True,
logger: logging.Logger | None = None,
):
self.label = label
self.log = log
self.logger = logger
self._start: float | None = None
self._history: list[float] = []
def __enter__(self):
return self.start()
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def start(self):
self._start = time.perf_counter()
return self
def stop(self) -> float:
if self._start is None:
raise RuntimeError("Timer was never started.")
elapsed = time.perf_counter() - self._start
self._history.append(elapsed)
self._start = None
if self.log:
if self.logger is not None:
self.logger.info(f"{self.label}: {elapsed:.6f} s")
else:
logging.info(f"{self.label}: {elapsed:.6f} s")
return elapsed
def reset(self):
self._history.clear()
@property
def last(self) -> float:
return self._history[-1] if self._history else 0.0
@property
def avg(self) -> float:
return mean(self._history) if self._history else 0.0
@property
def total(self) -> float:
return sum(self._history)
@property
def count(self) -> int:
return len(self._history)
@property
def history(self) -> list[float]:
return deepcopy(self._history)
@property
def fps_history(self) -> list[float]:
return [1.0 / t for t in self._history]
@property
def fps_last(self) -> float:
return 0.0 if self.last == 0 else 1.0 / self.last
@property
def fps_avg(self) -> float:
return 0.0 if self.avg == 0 else 1.0 / self.avg
def percentile(self, p: float) -> float:
"""
Return the p-th percentile of recorded times.
"""
if not self._history:
return 0.0
return float(np.percentile(self._history, p))
def fps_percentile(self, p: float) -> float:
"""
FPS corresponding to the p-th percentile time.
"""
val = self.percentile(p)
return 0.0 if val == 0 else 1.0 / val
+2 -2
View File
@@ -218,8 +218,8 @@ def record_loop(
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame}
dataset.add_frame(frame, task=single_task)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)
if display_data:
for obs, val in observation.items():
+2 -1
View File
@@ -227,7 +227,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
new_frame[key] = value
new_dataset.add_frame(new_frame, task=task)
new_frame["task"] = task
new_dataset.add_frame(new_frame)
if frame["episode_index"].item() != prev_episode_index:
# Save the episode
+2 -1
View File
@@ -2132,7 +2132,8 @@ def record_dataset(env, policy, cfg):
frame["complementary_info.discrete_penalty"] = torch.tensor(
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
)
dataset.add_frame(frame, task=cfg.task)
frame["task"] = cfg.task
dataset.add_frame(frame)
# Maintain consistent timing
if cfg.fps:
+8 -3
View File
@@ -37,7 +37,9 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.robots import make_robot_from_config
from lerobot.common.datasets.utils import hw_to_dataset_features
from tests.mocks.mock_robot import MockRobotConfig
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
@@ -66,9 +68,12 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
objects have the same sets of attributes defined.
"""
# Instantiate both ways
robot = make_robot("koch", mock=True)
robot = make_robot_from_config(MockRobotConfig())
action_features = hw_to_dataset_features(robot.action_features, "action", True)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
dataset_features = {**action_features, **obs_features}
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
+2 -6
View File
@@ -474,9 +474,7 @@ def lerobot_dataset_metadata_factory(
)
with (
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
):
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
@@ -558,9 +556,7 @@ def lerobot_dataset_factory(
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
):
mock_metadata_patch.return_value = mock_metadata
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version