mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
Feat(dataset_tools.py) Add modify tasks tool (#2875)
* feat(datasets): add modify_tasks function for in-place task editing. Add a new utility function to modify tasks in a LeRobotDataset in-place. This allows users to: - Set a single task for all episodes - Set specific tasks for individual episodes - Combine a default task with per-episode overrides * feat(edit-dataset): add CLI support for the modify_tasks operation. Integrate the modify_tasks function into the lerobot_edit_dataset CLI. Users can now modify dataset tasks via the command line; this supports setting a default task, per-episode tasks, or both combined. * test(datasets): add tests for the modify_tasks function. Add comprehensive test coverage for the modify_tasks utility: - Single task for all episodes - Episode-specific task assignment - Default task with per-episode overrides - Error handling for missing/invalid arguments - Verification of task_index correctness - In-place modification behavior - Metadata preservation * Respond to Copilot review
This commit is contained in:
@@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_features,
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
@@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
assert "reward" in modified_dataset.meta.features
|
||||
|
||||
|
||||
def test_modify_tasks_single_task_for_all(sample_dataset):
    """Check that passing only ``new_task`` applies it to every frame."""
    task_text = "Pick up the cube and place it"

    result = modify_tasks(sample_dataset, new_task=task_text)

    # The task table must contain exactly one entry: the new task.
    assert len(result.meta.tasks) == 1
    assert task_text in result.meta.tasks.index

    # With a single task, every frame is expected to carry task_index 0.
    for frame_idx in range(len(result)):
        frame = result[frame_idx]
        assert frame["task_index"].item() == 0
        assert frame["task"] == task_text
|
||||
|
||||
|
||||
def test_modify_tasks_episode_specific(sample_dataset):
    """Check that per-episode assignments land on the episodes they name."""
    assignments = {
        0: "Task A",
        1: "Task B",
        2: "Task A",
        3: "Task C",
        4: "Task B",
    }

    result = modify_tasks(sample_dataset, episode_tasks=assignments)

    # Repeated task strings must be counted once in the task table.
    distinct_tasks = set(assignments.values())
    assert len(result.meta.tasks) == len(distinct_tasks)

    # The first task recorded for each episode must be the one assigned.
    for episode_index, task in assignments.items():
        episode_row = result.meta.episodes[episode_index]
        assert episode_row["tasks"][0] == task
|
||||
|
||||
|
||||
def test_modify_tasks_default_with_overrides(sample_dataset):
    """Check combining a default ``new_task`` with per-episode overrides."""
    fallback = "Default task"
    special = "Special task"
    overrides = {2: special, 4: special}

    result = modify_tasks(
        sample_dataset,
        new_task=fallback,
        episode_tasks=overrides,
    )

    # Exactly two tasks must exist: the default and the override.
    assert len(result.meta.tasks) == 2
    assert fallback in result.meta.tasks.index
    assert special in result.meta.tasks.index

    # Overridden episodes get the special task; the rest get the default.
    for episode_index in range(5):
        episode_row = result.meta.episodes[episode_index]
        expected = special if episode_index in overrides else fallback
        assert episode_row["tasks"][0] == expected
|
||||
|
||||
|
||||
def test_modify_tasks_no_task_specified(sample_dataset):
    """Check that omitting both task arguments raises ``ValueError``."""
    expected_message = "Must specify at least one of new_task or episode_tasks"
    with pytest.raises(ValueError, match=expected_message):
        modify_tasks(sample_dataset)
|
||||
|
||||
|
||||
def test_modify_tasks_invalid_episode_indices(sample_dataset):
    """Check that episode indices 10 and 20 are rejected with ``ValueError``."""
    bogus_assignments = {10: "Task", 20: "Task"}
    with pytest.raises(ValueError, match="Invalid episode indices"):
        modify_tasks(sample_dataset, episode_tasks=bogus_assignments)
|
||||
|
||||
|
||||
def test_modify_tasks_updates_info_json(sample_dataset):
    """Check that ``meta.total_tasks`` reflects the number of distinct tasks."""
    labels = ["Task A", "Task B", "Task C", "Task A", "Task B"]
    # Episode i is assigned labels[i], i.e. {0: "Task A", ..., 4: "Task B"}.
    assignments = dict(enumerate(labels))

    result = modify_tasks(sample_dataset, episode_tasks=assignments)

    # Three distinct task strings were assigned above.
    assert result.meta.total_tasks == 3
|
||||
|
||||
|
||||
def test_modify_tasks_preserves_other_metadata(sample_dataset):
    """Check that frame count, episode count and fps survive a task change."""
    tracked_fields = ("total_frames", "total_episodes", "fps")

    # Snapshot the values before the call.
    before = {name: getattr(sample_dataset.meta, name) for name in tracked_fields}

    result = modify_tasks(sample_dataset, new_task="New task")

    # Each tracked value must be unchanged after the call.
    for name in tracked_fields:
        assert getattr(result.meta, name) == before[name]
|
||||
|
||||
|
||||
def test_modify_tasks_task_index_correct(sample_dataset):
    """Check that each frame's ``task_index`` matches its episode's task."""
    assignments = {
        0: "Alpha task",
        1: "Beta task",
        2: "Alpha task",
        3: "Gamma task",
        4: "Beta task",
    }
    # Expected indices follow alphabetical order of the task strings.
    # NOTE(review): relies on modify_tasks assigning indices in sorted order;
    # that ordering is not visible in this test file.
    expected_index_of = {
        "Alpha task": 0,
        "Beta task": 1,
        "Gamma task": 2,
    }

    result = modify_tasks(sample_dataset, episode_tasks=assignments)

    # Look up each frame's expected task through its episode index.
    for frame_idx in range(len(result)):
        frame = result[frame_idx]
        task = assignments[frame["episode_index"].item()]
        assert frame["task_index"].item() == expected_index_of[task]
        assert frame["task"] == task
|
||||
|
||||
|
||||
def test_modify_tasks_in_place(sample_dataset):
    """Check that ``modify_tasks`` returns the same dataset object it was given."""
    root_before = sample_dataset.root

    returned = modify_tasks(sample_dataset, new_task="New task")

    # Same instance is returned, and its root is unchanged.
    assert returned is sample_dataset
    assert returned.root == root_before
|
||||
|
||||
|
||||
def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset):
    """Check that episodes absent from ``episode_tasks`` keep their task when no ``new_task`` is given."""
    from lerobot.datasets.utils import load_episodes

    # Load the episodes metadata if it has not been loaded yet.
    if sample_dataset.meta.episodes is None:
        sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root)

    # Record the current first task of the episodes that will not be overridden.
    kept = {
        episode_index: sample_dataset.meta.episodes[episode_index]["tasks"][0]
        for episode_index in (0, 1)
    }

    # Override episodes 2, 3 and 4 only.
    overrides = {2: "New Task A", 3: "New Task B", 4: "New Task A"}

    result = modify_tasks(sample_dataset, episode_tasks=overrides)

    # Episodes 0 and 1 must still carry their original task.
    for episode_index, original_task in kept.items():
        assert result.meta.episodes[episode_index]["tasks"][0] == original_task

    # Overridden episodes must carry the newly assigned task.
    for episode_index, new_task in overrides.items():
        assert result.meta.episodes[episode_index]["tasks"][0] == new_task
|
||||
|
||||
|
||||
def test_convert_image_to_video_dataset(tmp_path):
|
||||
"""Test converting lerobot/pusht_image dataset to video format."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
Reference in New Issue
Block a user