fix(tests)

- Updated `lerobot_dataset.py:add_frame` to take task as key in frame
- Updated `lerobot_dataset.py` to remove robot argument from `create` function of lerobotdataset and lerobotdatasetmetadata and directly take the features
- Update `test_datasets.py` to features from Mock robot
- Update all the usage of `add_frame` in the library
- Update `dataset_factories.py`; had issues with new argument order
- Raise ValueError when no task is provided (in `datasets/utils.py` validate func)
This commit is contained in:
Michel Aractingi
2025-07-01 16:06:48 +02:00
parent 0a1da47527
commit 5e39b4ce94
6 changed files with 18 additions and 10 deletions
+2 -2
View File
@@ -51,10 +51,10 @@ while i < NB_CYCLES_CLIENT_CONNECTION:
action_sent = robot.send_action(action) action_sent = robot.send_action(action)
observation = robot.get_observation() observation = robot.get_observation()
frame = {**action_sent, **observation}
task = "Dummy Example Task Dataset" task = "Dummy Example Task Dataset"
frame = {**action_sent, **observation, "task": task}
dataset.add_frame(frame, task) dataset.add_frame(frame)
i += 1 i += 1
print("Disconnecting Teleop Devices and LeKiwi Client") print("Disconnecting Teleop Devices and LeKiwi Client")
+2 -2
View File
@@ -218,8 +218,8 @@ def record_loop(
if dataset is not None: if dataset is not None:
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame} frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame, task=single_task) dataset.add_frame(frame)
if display_data: if display_data:
for obs, val in observation.items(): for obs, val in observation.items():
+2 -1
View File
@@ -227,7 +227,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
new_frame[key] = value new_frame[key] = value
new_dataset.add_frame(new_frame, task=task) new_frame["task"] = task
new_dataset.add_frame(new_frame)
if frame["episode_index"].item() != prev_episode_index: if frame["episode_index"].item() != prev_episode_index:
# Save the episode # Save the episode
+2 -1
View File
@@ -2132,7 +2132,8 @@ def record_dataset(env, policy, cfg):
frame["complementary_info.discrete_penalty"] = torch.tensor( frame["complementary_info.discrete_penalty"] = torch.tensor(
[info.get("discrete_penalty", 0.0)], dtype=torch.float32 [info.get("discrete_penalty", 0.0)], dtype=torch.float32
) )
dataset.add_frame(frame, task=cfg.task) frame["task"] = cfg.task
dataset.add_frame(frame)
# Maintain consistent timing # Maintain consistent timing
if cfg.fps: if cfg.fps:
+2 -1
View File
@@ -596,7 +596,8 @@ class ReplayBuffer:
frame_dict[f"complementary_info.{key}"] = val frame_dict[f"complementary_info.{key}"] = val
# Add to the dataset's buffer # Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict, task=task_name) frame_dict["task"] = task_name
lerobot_dataset.add_frame(frame_dict)
# Move to next frame # Move to next frame
frame_idx_in_episode += 1 frame_idx_in_episode += 1
+8 -3
View File
@@ -37,7 +37,9 @@ from lerobot.common.datasets.utils import (
) )
from lerobot.common.envs.factory import make_env_config from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.common.robots import make_robot_from_config
from lerobot.common.datasets.utils import hw_to_dataset_features
from tests.mocks.mock_robot import MockRobotConfig
from lerobot.configs.default import DatasetConfig from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
@@ -66,9 +68,12 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
objects have the same sets of attributes defined. objects have the same sets of attributes defined.
""" """
# Instantiate both ways # Instantiate both ways
robot = make_robot("koch", mock=True) robot = make_robot_from_config(MockRobotConfig())
action_features = hw_to_dataset_features(robot.action_features, "action", True)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
dataset_features = {**action_features, **obs_features}
root_create = tmp_path / "create" root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create)
root_init = tmp_path / "init" root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1) dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)