fix(tests)

- Updated `lerobot_dataset.py:add_frame` to take task as key in frame
- Updated `lerobot_dataset.py` to remove robot argument from `create` function of lerobotdataset and lerobotdatasetmetadata and directly take the features
- Update `test_datasets.py` to features from Mock robot
- Update all the usage of `add_frame` in the library
- Update `dataset_factories.py`; had issues with new argument order
- Raise ValueError when no task is provided (in `datasets/utils.py` validate func)
This commit is contained in:
Michel Aractingi
2025-07-01 16:06:48 +02:00
parent 0a1da47527
commit 5e39b4ce94
6 changed files with 18 additions and 10 deletions
+2 -2
View File
@@ -51,10 +51,10 @@ while i < NB_CYCLES_CLIENT_CONNECTION:
action_sent = robot.send_action(action)
observation = robot.get_observation()
frame = {**action_sent, **observation}
task = "Dummy Example Task Dataset"
frame = {**action_sent, **observation, "task": task}
dataset.add_frame(frame, task)
dataset.add_frame(frame)
i += 1
print("Disconnecting Teleop Devices and LeKiwi Client")
+2 -2
View File
@@ -218,8 +218,8 @@ def record_loop(
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame}
dataset.add_frame(frame, task=single_task)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)
if display_data:
for obs, val in observation.items():
+2 -1
View File
@@ -227,7 +227,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
new_frame[key] = value
new_dataset.add_frame(new_frame, task=task)
new_frame["task"] = task
new_dataset.add_frame(new_frame)
if frame["episode_index"].item() != prev_episode_index:
# Save the episode
+2 -1
View File
@@ -2132,7 +2132,8 @@ def record_dataset(env, policy, cfg):
frame["complementary_info.discrete_penalty"] = torch.tensor(
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
)
dataset.add_frame(frame, task=cfg.task)
frame["task"] = cfg.task
dataset.add_frame(frame)
# Maintain consistent timing
if cfg.fps:
+2 -1
View File
@@ -596,7 +596,8 @@ class ReplayBuffer:
frame_dict[f"complementary_info.{key}"] = val
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict, task=task_name)
frame_dict["task"] = task_name
lerobot_dataset.add_frame(frame_dict)
# Move to next frame
frame_idx_in_episode += 1
+8 -3
View File
@@ -37,7 +37,9 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.robots import make_robot_from_config
from lerobot.common.datasets.utils import hw_to_dataset_features
from tests.mocks.mock_robot import MockRobotConfig
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
@@ -66,9 +68,12 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
objects have the same sets of attributes defined.
"""
# Instantiate both ways
robot = make_robot("koch", mock=True)
robot = make_robot_from_config(MockRobotConfig())
action_features = hw_to_dataset_features(robot.action_features, "action", True)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
dataset_features = {**action_features, **obs_features}
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)