From 5e39b4ce943cc9ae5562ca3acfb2c32da99b13d7 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Tue, 1 Jul 2025 16:06:48 +0200 Subject: [PATCH] fix(tests) - Updated `lerobot_dataset.py:add_frame` to take task as key in frame - Updated `lerobot_dataset.py` to remove robot argument from `create` function of lerobotdataset and lerobotdatasetmetadata and directly take the features - Update `test_datasets.py` to features from Mock robot - Update all the usage of `add_frame` in the library - Update `dataset_factories.py`; had issues with new argument order - Raise ValueError when no task is provided (in `datasets/utils.py` validate func) --- examples/lekiwi/record.py | 4 ++-- src/lerobot/record.py | 4 ++-- src/lerobot/scripts/rl/crop_dataset_roi.py | 3 ++- src/lerobot/scripts/rl/gym_manipulator.py | 3 ++- src/lerobot/utils/buffer.py | 3 ++- tests/datasets/test_datasets.py | 11 ++++++++--- 6 files changed, 18 insertions(+), 10 deletions(-) diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 2ad32677f..e6b774f19 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -51,10 +51,10 @@ while i < NB_CYCLES_CLIENT_CONNECTION: action_sent = robot.send_action(action) observation = robot.get_observation() - frame = {**action_sent, **observation} task = "Dummy Example Task Dataset" + frame = {**action_sent, **observation, "task": task} - dataset.add_frame(frame, task) + dataset.add_frame(frame) i += 1 print("Disconnecting Teleop Devices and LeKiwi Client") diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 54d7f3952..5d4638cce 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -218,8 +218,8 @@ def record_loop( if dataset is not None: action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") - frame = {**observation_frame, **action_frame} - dataset.add_frame(frame, task=single_task) + frame = {**observation_frame, **action_frame, "task": single_task} + dataset.add_frame(frame) if display_data: for obs, val in observation.items(): diff --git a/src/lerobot/scripts/rl/crop_dataset_roi.py b/src/lerobot/scripts/rl/crop_dataset_roi.py index 4cb7a3e8a..3d57ddb99 100644 --- a/src/lerobot/scripts/rl/crop_dataset_roi.py +++ b/src/lerobot/scripts/rl/crop_dataset_roi.py @@ -227,7 +227,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( new_frame[key] = value - new_dataset.add_frame(new_frame, task=task) + new_frame["task"] = task + new_dataset.add_frame(new_frame) if frame["episode_index"].item() != prev_episode_index: # Save the episode diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 76a136084..41f75886d 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -2132,7 +2132,8 @@ def record_dataset(env, policy, cfg): frame["complementary_info.discrete_penalty"] = torch.tensor( [info.get("discrete_penalty", 0.0)], dtype=torch.float32 ) - dataset.add_frame(frame, task=cfg.task) + frame["task"] = cfg.task + dataset.add_frame(frame) # Maintain consistent timing if cfg.fps: diff --git a/src/lerobot/utils/buffer.py b/src/lerobot/utils/buffer.py index 7f8d989dd..e276ef453 100644 --- a/src/lerobot/utils/buffer.py +++ b/src/lerobot/utils/buffer.py @@ -596,7 +596,8 @@ class ReplayBuffer: frame_dict[f"complementary_info.{key}"] = val # Add to the dataset's buffer - lerobot_dataset.add_frame(frame_dict, task=task_name) + frame_dict["task"] = task_name + lerobot_dataset.add_frame(frame_dict) # Move to next frame frame_idx_in_episode += 1 diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 1557c3b7a..7d390c243 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -37,7 +37,9 @@ from lerobot.common.datasets.utils import ( ) from lerobot.common.envs.factory import make_env_config from lerobot.common.policies.factory import make_policy_config -from lerobot.common.robot_devices.robots.utils import make_robot +from lerobot.common.robots import make_robot_from_config +from lerobot.common.datasets.utils import hw_to_dataset_features +from tests.mocks.mock_robot import MockRobotConfig from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID @@ -66,9 +68,12 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): objects have the same sets of attributes defined. """ # Instantiate both ways - robot = make_robot("koch", mock=True) + robot = make_robot_from_config(MockRobotConfig()) + action_features = hw_to_dataset_features(robot.action_features, "action", True) + obs_features = hw_to_dataset_features(robot.observation_features, "observation", True) + dataset_features = {**action_features, **obs_features} root_create = tmp_path / "create" - dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) + dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create) root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)