diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 405a41bd3..61100a231 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -51,10 +51,10 @@ while i < NB_CYCLES_CLIENT_CONNECTION: action_sent = robot.send_action(action) observation = robot.get_observation() - frame = {**action_sent, **observation} task = "Dummy Example Task Dataset" + frame = {**action_sent, **observation, "task": task} - dataset.add_frame(frame, task) + dataset.add_frame(frame) i += 1 print("Disconnecting Teleop Devices and LeKiwi Client") diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index e7b71827c..79a89d6ec 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -40,6 +40,7 @@ from lerobot.common.datasets.utils import ( DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, INFO_PATH, + _validate_feature_names, check_delta_timestamps, check_version_compatibility, concat_video_files, @@ -79,7 +80,7 @@ from lerobot.common.datasets.video_utils import ( get_safe_default_codec, get_video_info, ) -from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robots.utils import Robot CODEBASE_VERSION = "v3.0" @@ -375,10 +376,9 @@ class LeRobotDatasetMetadata: cls, repo_id: str, fps: int, - root: str | Path | None = None, - robot: Robot | None = None, + features: dict, robot_type: str | None = None, - features: dict | None = None, + root: str | Path | None = None, use_videos: bool = True, ) -> "LeRobotDatasetMetadata": """Creates metadata for a LeRobotDataset.""" @@ -388,34 +388,13 @@ class LeRobotDatasetMetadata: obj.root.mkdir(parents=True, exist_ok=False) - if robot is not None: - features = get_features_from_robot(robot, use_videos) - robot_type = robot.robot_type - if not all(cam.fps == fps for cam in robot.cameras.values()): - logging.warning( - f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset." - "In this case, frames from lower fps cameras will be repeated to fill in the blanks." - ) - elif features is None: - raise ValueError( - "Dataset features must either come from a Robot or explicitly passed upon creation." - ) - else: - # TODO(aliberts, rcadene): implement sanity check for features - features = {**features, **DEFAULT_FEATURES} - - # check if none of the features contains a "/" in their names, - # as this would break the dict flattening in the stats computation, which uses '/' as separator - for key in features: - if "/" in key: - raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.") - - features = {**features, **DEFAULT_FEATURES} + features = {**features, **DEFAULT_FEATURES} + _validate_feature_names(features) obj.tasks = None obj.episodes = None obj.stats = None - obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) + obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) @@ -867,14 +846,10 @@ class LeRobotDataset(torch.utils.data.Dataset): timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) + self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing # Add frame features to episode_buffer for key in frame: - if key == "task": - # Note: we associate the task in natural language to its task index during `save_episode` - self.episode_buffer["task"].append(frame["task"]) - continue - if key not in self.features: raise ValueError( f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." @@ -1132,10 +1107,9 @@ class LeRobotDataset(torch.utils.data.Dataset): cls, repo_id: str, fps: int, + features: dict, root: str | Path | None = None, - robot: Robot | None = None, robot_type: str | None = None, - features: dict | None = None, use_videos: bool = True, tolerance_s: float = 1e-4, image_writer_processes: int = 0, @@ -1147,10 +1121,9 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.meta = LeRobotDatasetMetadata.create( repo_id=repo_id, fps=fps, - root=root, - robot=robot, robot_type=robot_type, features=features, + root=root, use_videos=use_videos, ) obj.repo_id = obj.meta.repo_id diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index d5dc702dc..deebc1ac3 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -834,10 +834,17 @@ def validate_frame(frame: dict, features: dict): expected_features = set(features) - set(DEFAULT_FEATURES) actual_features = set(frame) - error_message = validate_features_presence(actual_features, expected_features) + # task is a special required field that's not part of regular features + if "task" not in actual_features: + raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n") - common_features = actual_features & expected_features - for name in common_features - {"task"}: + # Remove task from actual_features for regular feature validation + actual_features_for_validation = actual_features - {"task"} + + error_message = validate_features_presence(actual_features_for_validation, expected_features) + + common_features = actual_features_for_validation & expected_features + for name in common_features: error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) if error_message: diff --git a/lerobot/common/utils/buffer.py b/lerobot/common/utils/buffer.py index 9ae231ad9..eab1d24ed 100644 --- a/lerobot/common/utils/buffer.py +++ b/lerobot/common/utils/buffer.py @@ -596,7 +596,8 @@ class ReplayBuffer: frame_dict[f"complementary_info.{key}"] = val # Add to the dataset's buffer - lerobot_dataset.add_frame(frame_dict, task=task_name) + frame_dict["task"] = task_name + lerobot_dataset.add_frame(frame_dict) # Move to next frame frame_idx_in_episode += 1 diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index bcb4d6bc5..b7c104cf6 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -17,10 +17,14 @@ import logging import os import os.path as osp import platform +import select import subprocess -from copy import copy +import sys +import time +from copy import copy, deepcopy from datetime import datetime, timezone from pathlib import Path +from statistics import mean import numpy as np import torch @@ -107,11 +111,17 @@ def is_amp_available(device: str): raise ValueError(f"Unknown device '{device}.") -def init_logging(): +def init_logging(log_file: Path | None = None, display_pid: bool = False): def custom_format(record): dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" - message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" + + # NOTE: Display PID is useful for multi-process logging. + if display_pid: + pid_str = f"[PID: {os.getpid()}]" + message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}" + else: + message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" return message logging.basicConfig(level=logging.INFO) @@ -125,6 +135,12 @@ def init_logging(): console_handler.setFormatter(formatter) logging.getLogger().addHandler(console_handler) + if log_file is not None: + # Additionally write logs to file + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logging.getLogger().addHandler(file_handler) + def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] @@ -168,7 +184,7 @@ def capture_timestamp_utc(): return datetime.now(timezone.utc) -def say(text, blocking=False): +def say(text: str, blocking: bool = False): system = platform.system() if system == "Darwin": @@ -196,7 +212,7 @@ def say(text, blocking=False): subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0) -def log_say(text, play_sounds, blocking=False): +def log_say(text: str, play_sounds: bool = True, blocking: bool = False): logging.info(text) if play_sounds: @@ -230,6 +246,23 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool: return False +def enter_pressed() -> bool: + if platform.system() == "Windows": + import msvcrt + + if msvcrt.kbhit(): + key = msvcrt.getch() + return key in (b"\r", b"\n") # enter key + return False + else: + return select.select([sys.stdin], [], [], 0)[0] and sys.stdin.readline().strip() == "" + + +def move_cursor_up(lines): + """Move the cursor up by a specified number of lines.""" + print(f"\033[{lines}A", end="") + + def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float): days = int(elapsed_time_s // (24 * 3600)) elapsed_time_s %= 24 * 3600 @@ -238,3 +271,114 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float): minutes = int(elapsed_time_s // 60) seconds = elapsed_time_s % 60 return days, hours, minutes, seconds + + +class TimerManager: + """ + Lightweight utility to measure elapsed time. + + Examples + -------- + ```python + # Example 1: Using context manager + timer = TimerManager("Policy", log=False) + for _ in range(3): + with timer: + time.sleep(0.01) + print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01 + ``` + + ```python + # Example 2: Using start/stop methods + timer = TimerManager("Policy", log=False) + timer.start() + time.sleep(0.01) + timer.stop() + print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01 + ``` + """ + + def __init__( + self, + label: str = "Elapsed-time", + log: bool = True, + logger: logging.Logger | None = None, + ): + self.label = label + self.log = log + self.logger = logger + self._start: float | None = None + self._history: list[float] = [] + + def __enter__(self): + return self.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + def start(self): + self._start = time.perf_counter() + return self + + def stop(self) -> float: + if self._start is None: + raise RuntimeError("Timer was never started.") + elapsed = time.perf_counter() - self._start + self._history.append(elapsed) + self._start = None + if self.log: + if self.logger is not None: + self.logger.info(f"{self.label}: {elapsed:.6f} s") + else: + logging.info(f"{self.label}: {elapsed:.6f} s") + return elapsed + + def reset(self): + self._history.clear() + + @property + def last(self) -> float: + return self._history[-1] if self._history else 0.0 + + @property + def avg(self) -> float: + return mean(self._history) if self._history else 0.0 + + @property + def total(self) -> float: + return sum(self._history) + + @property + def count(self) -> int: + return len(self._history) + + @property + def history(self) -> list[float]: + return deepcopy(self._history) + + @property + def fps_history(self) -> list[float]: + return [1.0 / t for t in self._history] + + @property + def fps_last(self) -> float: + return 0.0 if self.last == 0 else 1.0 / self.last + + @property + def fps_avg(self) -> float: + return 0.0 if self.avg == 0 else 1.0 / self.avg + + def percentile(self, p: float) -> float: + """ + Return the p-th percentile of recorded times. + """ + if not self._history: + return 0.0 + return float(np.percentile(self._history, p)) + + def fps_percentile(self, p: float) -> float: + """ + FPS corresponding to the p-th percentile time. + """ + val = self.percentile(p) + return 0.0 if val == 0 else 1.0 / val diff --git a/lerobot/record.py b/lerobot/record.py index ce6f538d5..b2aae07e0 100644 --- a/lerobot/record.py +++ b/lerobot/record.py @@ -218,8 +218,8 @@ def record_loop( if dataset is not None: action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") - frame = {**observation_frame, **action_frame} - dataset.add_frame(frame, task=single_task) + frame = {**observation_frame, **action_frame, "task": single_task} + dataset.add_frame(frame) if display_data: for obs, val in observation.items(): diff --git a/lerobot/scripts/rl/crop_dataset_roi.py b/lerobot/scripts/rl/crop_dataset_roi.py index 5b7038de3..4c53bc522 100644 --- a/lerobot/scripts/rl/crop_dataset_roi.py +++ b/lerobot/scripts/rl/crop_dataset_roi.py @@ -227,7 +227,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( new_frame[key] = value - new_dataset.add_frame(new_frame, task=task) + new_frame["task"] = task + new_dataset.add_frame(new_frame) if frame["episode_index"].item() != prev_episode_index: # Save the episode diff --git a/lerobot/scripts/rl/gym_manipulator.py b/lerobot/scripts/rl/gym_manipulator.py index e7327d96d..4f0000beb 100644 --- a/lerobot/scripts/rl/gym_manipulator.py +++ b/lerobot/scripts/rl/gym_manipulator.py @@ -2132,7 +2132,8 @@ def record_dataset(env, policy, cfg): frame["complementary_info.discrete_penalty"] = torch.tensor( [info.get("discrete_penalty", 0.0)], dtype=torch.float32 ) - dataset.add_frame(frame, task=cfg.task) + frame["task"] = cfg.task + dataset.add_frame(frame) # Maintain consistent timing if cfg.fps: diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 1557c3b7a..7d390c243 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -37,7 +37,9 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.envs.factory import make_env_config from lerobot.common.policies.factory import make_policy_config -from lerobot.common.robot_devices.robots.utils import make_robot +from lerobot.common.robots import make_robot_from_config +from lerobot.common.datasets.utils import hw_to_dataset_features +from tests.mocks.mock_robot import MockRobotConfig from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID @@ -66,9 +68,12 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): objects have the same sets of attributes defined. """ # Instantiate both ways - robot = make_robot("koch", mock=True) + robot = make_robot_from_config(MockRobotConfig()) + action_features = hw_to_dataset_features(robot.action_features, "action", True) + obs_features = hw_to_dataset_features(robot.observation_features, "observation", True) + dataset_features = {**action_features, **obs_features} root_create = tmp_path / "create" - dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) + dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create) root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 7e98c51f7..b36111c02 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -474,9 +474,7 @@ def lerobot_dataset_metadata_factory( ) with ( patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, - patch( - "lerobot.common.datasets.lerobot_dataset.snapshot_download" - ) as mock_snapshot_download_patch, + patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch, ): mock_get_safe_version_patch.side_effect = lambda repo_id, version: version mock_snapshot_download_patch.side_effect = mock_snapshot_download @@ -558,9 +556,7 @@ def lerobot_dataset_factory( with ( patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch, patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, - patch( - "lerobot.common.datasets.lerobot_dataset.snapshot_download" - ) as mock_snapshot_download_patch, + patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch, ): mock_metadata_patch.return_value = mock_metadata mock_get_safe_version_patch.side_effect = lambda repo_id, version: version