mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
fix(tests)
- Updated `lerobot_dataset.py:add_frame` to take task as key in frame - Updated `lerobot_dataset.py` to remove robot argument from `create` function of lerobotdataset and lerobotdatasetmetadata and directly take the features - Update `test_datasets.py` to features from Mock robot - Update all the usage of `add_frame` in the library - Update `dataset_factories.py`; had issues with new argument order - Raise ValueError when no task is provided (in `datasets/utils.py` validate func)
This commit is contained in:
@@ -51,10 +51,10 @@ while i < NB_CYCLES_CLIENT_CONNECTION:
|
|||||||
action_sent = robot.send_action(action)
|
action_sent = robot.send_action(action)
|
||||||
observation = robot.get_observation()
|
observation = robot.get_observation()
|
||||||
|
|
||||||
frame = {**action_sent, **observation}
|
|
||||||
task = "Dummy Example Task Dataset"
|
task = "Dummy Example Task Dataset"
|
||||||
|
frame = {**action_sent, **observation, "task": task}
|
||||||
|
|
||||||
dataset.add_frame(frame, task)
|
dataset.add_frame(frame)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
print("Disconnecting Teleop Devices and LeKiwi Client")
|
print("Disconnecting Teleop Devices and LeKiwi Client")
|
||||||
|
|||||||
@@ -218,8 +218,8 @@ def record_loop(
|
|||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||||
frame = {**observation_frame, **action_frame}
|
frame = {**observation_frame, **action_frame, "task": single_task}
|
||||||
dataset.add_frame(frame, task=single_task)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
if display_data:
|
if display_data:
|
||||||
for obs, val in observation.items():
|
for obs, val in observation.items():
|
||||||
|
|||||||
@@ -227,7 +227,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
|||||||
|
|
||||||
new_frame[key] = value
|
new_frame[key] = value
|
||||||
|
|
||||||
new_dataset.add_frame(new_frame, task=task)
|
new_frame["task"] = task
|
||||||
|
new_dataset.add_frame(new_frame)
|
||||||
|
|
||||||
if frame["episode_index"].item() != prev_episode_index:
|
if frame["episode_index"].item() != prev_episode_index:
|
||||||
# Save the episode
|
# Save the episode
|
||||||
|
|||||||
@@ -2132,7 +2132,8 @@ def record_dataset(env, policy, cfg):
|
|||||||
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
||||||
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
||||||
)
|
)
|
||||||
dataset.add_frame(frame, task=cfg.task)
|
frame["task"] = cfg.task
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
# Maintain consistent timing
|
# Maintain consistent timing
|
||||||
if cfg.fps:
|
if cfg.fps:
|
||||||
|
|||||||
@@ -596,7 +596,8 @@ class ReplayBuffer:
|
|||||||
frame_dict[f"complementary_info.{key}"] = val
|
frame_dict[f"complementary_info.{key}"] = val
|
||||||
|
|
||||||
# Add to the dataset's buffer
|
# Add to the dataset's buffer
|
||||||
lerobot_dataset.add_frame(frame_dict, task=task_name)
|
frame_dict["task"] = task_name
|
||||||
|
lerobot_dataset.add_frame(frame_dict)
|
||||||
|
|
||||||
# Move to next frame
|
# Move to next frame
|
||||||
frame_idx_in_episode += 1
|
frame_idx_in_episode += 1
|
||||||
|
|||||||
@@ -37,7 +37,9 @@ from lerobot.common.datasets.utils import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.envs.factory import make_env_config
|
from lerobot.common.envs.factory import make_env_config
|
||||||
from lerobot.common.policies.factory import make_policy_config
|
from lerobot.common.policies.factory import make_policy_config
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
from lerobot.common.robots import make_robot_from_config
|
||||||
|
from lerobot.common.datasets.utils import hw_to_dataset_features
|
||||||
|
from tests.mocks.mock_robot import MockRobotConfig
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||||
@@ -66,9 +68,12 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
|||||||
objects have the same sets of attributes defined.
|
objects have the same sets of attributes defined.
|
||||||
"""
|
"""
|
||||||
# Instantiate both ways
|
# Instantiate both ways
|
||||||
robot = make_robot("koch", mock=True)
|
robot = make_robot_from_config(MockRobotConfig())
|
||||||
|
action_features = hw_to_dataset_features(robot.action_features, "action", True)
|
||||||
|
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
|
||||||
|
dataset_features = {**action_features, **obs_features}
|
||||||
root_create = tmp_path / "create"
|
root_create = tmp_path / "create"
|
||||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create)
|
||||||
|
|
||||||
root_init = tmp_path / "init"
|
root_init = tmp_path / "init"
|
||||||
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user