mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix(tests)
- Updated `lerobot_dataset.py:add_frame` to take task as key in frame - Updated `lerobot_dataset.py` to remove robot argument from `create` function of lerobotdataset and lerobotdatasetmetadata and directly take the features - Update `test_datasets.py` to features from Mock robot - Update all the usage of `add_frame` in the library - Update `dataset_factories.py`; had issues with new argument order - Raise ValueError when no task is provided (in `datasets/utils.py` validate func)
This commit is contained in:
@@ -51,10 +51,10 @@ while i < NB_CYCLES_CLIENT_CONNECTION:
|
||||
action_sent = robot.send_action(action)
|
||||
observation = robot.get_observation()
|
||||
|
||||
frame = {**action_sent, **observation}
|
||||
task = "Dummy Example Task Dataset"
|
||||
frame = {**action_sent, **observation, "task": task}
|
||||
|
||||
dataset.add_frame(frame, task)
|
||||
dataset.add_frame(frame)
|
||||
i += 1
|
||||
|
||||
print("Disconnecting Teleop Devices and LeKiwi Client")
|
||||
|
||||
@@ -218,8 +218,8 @@ def record_loop(
|
||||
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||
frame = {**observation_frame, **action_frame}
|
||||
dataset.add_frame(frame, task=single_task)
|
||||
frame = {**observation_frame, **action_frame, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_data:
|
||||
for obs, val in observation.items():
|
||||
|
||||
@@ -227,7 +227,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
|
||||
new_frame[key] = value
|
||||
|
||||
new_dataset.add_frame(new_frame, task=task)
|
||||
new_frame["task"] = task
|
||||
new_dataset.add_frame(new_frame)
|
||||
|
||||
if frame["episode_index"].item() != prev_episode_index:
|
||||
# Save the episode
|
||||
|
||||
@@ -2132,7 +2132,8 @@ def record_dataset(env, policy, cfg):
|
||||
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
||||
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
||||
)
|
||||
dataset.add_frame(frame, task=cfg.task)
|
||||
frame["task"] = cfg.task
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Maintain consistent timing
|
||||
if cfg.fps:
|
||||
|
||||
@@ -596,7 +596,8 @@ class ReplayBuffer:
|
||||
frame_dict[f"complementary_info.{key}"] = val
|
||||
|
||||
# Add to the dataset's buffer
|
||||
lerobot_dataset.add_frame(frame_dict, task=task_name)
|
||||
frame_dict["task"] = task_name
|
||||
lerobot_dataset.add_frame(frame_dict)
|
||||
|
||||
# Move to next frame
|
||||
frame_idx_in_episode += 1
|
||||
|
||||
@@ -37,7 +37,9 @@ from lerobot.common.datasets.utils import (
|
||||
)
|
||||
from lerobot.common.envs.factory import make_env_config
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from lerobot.common.robots import make_robot_from_config
|
||||
from lerobot.common.datasets.utils import hw_to_dataset_features
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
@@ -66,9 +68,12 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
objects have the same sets of attributes defined.
|
||||
"""
|
||||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
robot = make_robot_from_config(MockRobotConfig())
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", True)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
||||
|
||||
Reference in New Issue
Block a user