Feat(dataset_tools.py) Add modify tasks tool (#2875)

* feat(datasets): add modify_tasks function for in-place task editing

Add a new utility function to modify tasks in LeRobotDataset in-place.
This allows users to:
- Set a single task for all episodes
- Set specific tasks for individual episodes
- Combine a default task with per-episode overrides

* feat(edit-dataset): add CLI support for modify_tasks operation

Integrate the modify_tasks function into lerobot_edit_dataset CLI.
Users can now modify dataset tasks via command line:
Supports setting a default task, per-episode tasks, or both combined.

* test(datasets): add tests for modify_tasks function

Add comprehensive test coverage for the modify_tasks utility:
- Single task for all episodes
- Episode-specific task assignment
- Default task with per-episode overrides
- Error handling for missing/invalid arguments
- Verification of task_index correctness
- In-place modification behavior
- Metadata preservation

* respond to copilot review
This commit is contained in:
Michel Aractingi
2026-01-30 13:19:42 +01:00
committed by GitHub
parent 04cbf669cf
commit ec04b7ce3a
3 changed files with 374 additions and 3 deletions
+169
View File
@@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import (
delete_episodes,
merge_datasets,
modify_features,
modify_tasks,
remove_feature,
split_dataset,
)
@@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
assert "reward" in modified_dataset.meta.features
def test_modify_tasks_single_task_for_all(sample_dataset):
"""Test setting a single task for all episodes."""
new_task = "Pick up the cube and place it"
modified_dataset = modify_tasks(sample_dataset, new_task=new_task)
# Verify all episodes have the new task
assert len(modified_dataset.meta.tasks) == 1
assert new_task in modified_dataset.meta.tasks.index
# Verify task_index is 0 for all frames (only one task)
for i in range(len(modified_dataset)):
item = modified_dataset[i]
assert item["task_index"].item() == 0
assert item["task"] == new_task
def test_modify_tasks_episode_specific(sample_dataset):
"""Test setting different tasks for specific episodes."""
episode_tasks = {
0: "Task A",
1: "Task B",
2: "Task A",
3: "Task C",
4: "Task B",
}
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
# Verify correct number of unique tasks
unique_tasks = set(episode_tasks.values())
assert len(modified_dataset.meta.tasks) == len(unique_tasks)
# Verify each episode has the correct task
for ep_idx, expected_task in episode_tasks.items():
ep_data = modified_dataset.meta.episodes[ep_idx]
assert ep_data["tasks"][0] == expected_task
def test_modify_tasks_default_with_overrides(sample_dataset):
"""Test setting a default task with specific overrides."""
default_task = "Default task"
override_task = "Special task"
episode_tasks = {2: override_task, 4: override_task}
modified_dataset = modify_tasks(
sample_dataset,
new_task=default_task,
episode_tasks=episode_tasks,
)
# Verify correct number of unique tasks
assert len(modified_dataset.meta.tasks) == 2
assert default_task in modified_dataset.meta.tasks.index
assert override_task in modified_dataset.meta.tasks.index
# Verify episodes have correct tasks
for ep_idx in range(5):
ep_data = modified_dataset.meta.episodes[ep_idx]
if ep_idx in episode_tasks:
assert ep_data["tasks"][0] == override_task
else:
assert ep_data["tasks"][0] == default_task
def test_modify_tasks_no_task_specified(sample_dataset):
"""Test error when no task is specified."""
with pytest.raises(ValueError, match="Must specify at least one of new_task or episode_tasks"):
modify_tasks(sample_dataset)
def test_modify_tasks_invalid_episode_indices(sample_dataset):
"""Test error with invalid episode indices."""
with pytest.raises(ValueError, match="Invalid episode indices"):
modify_tasks(sample_dataset, episode_tasks={10: "Task", 20: "Task"})
def test_modify_tasks_updates_info_json(sample_dataset):
"""Test that total_tasks is updated in info.json."""
episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"}
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
# Verify total_tasks is updated
assert modified_dataset.meta.total_tasks == 3
def test_modify_tasks_preserves_other_metadata(sample_dataset):
"""Test that modifying tasks preserves other metadata."""
original_frames = sample_dataset.meta.total_frames
original_episodes = sample_dataset.meta.total_episodes
original_fps = sample_dataset.meta.fps
modified_dataset = modify_tasks(sample_dataset, new_task="New task")
# Verify other metadata is preserved
assert modified_dataset.meta.total_frames == original_frames
assert modified_dataset.meta.total_episodes == original_episodes
assert modified_dataset.meta.fps == original_fps
def test_modify_tasks_task_index_correct(sample_dataset):
"""Test that task_index values are correct in data files."""
# Create tasks that will have predictable indices (sorted alphabetically)
episode_tasks = {
0: "Alpha task", # Will be index 0
1: "Beta task", # Will be index 1
2: "Alpha task", # Will be index 0
3: "Gamma task", # Will be index 2
4: "Beta task", # Will be index 1
}
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
# Verify task indices are correct
task_to_expected_idx = {
"Alpha task": 0,
"Beta task": 1,
"Gamma task": 2,
}
for i in range(len(modified_dataset)):
item = modified_dataset[i]
ep_idx = item["episode_index"].item()
expected_task = episode_tasks[ep_idx]
expected_idx = task_to_expected_idx[expected_task]
assert item["task_index"].item() == expected_idx
assert item["task"] == expected_task
def test_modify_tasks_in_place(sample_dataset):
"""Test that modify_tasks modifies the dataset in-place."""
original_root = sample_dataset.root
modified_dataset = modify_tasks(sample_dataset, new_task="New task")
# Verify same instance is returned and root is unchanged
assert modified_dataset is sample_dataset
assert modified_dataset.root == original_root
def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset):
"""Test that original tasks are kept when using episode_tasks without new_task."""
from lerobot.datasets.utils import load_episodes
# Ensure episodes metadata is loaded
if sample_dataset.meta.episodes is None:
sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root)
# Get original tasks for episodes not being overridden
original_task_ep0 = sample_dataset.meta.episodes[0]["tasks"][0]
original_task_ep1 = sample_dataset.meta.episodes[1]["tasks"][0]
# Only override episodes 2, 3, 4
episode_tasks = {2: "New Task A", 3: "New Task B", 4: "New Task A"}
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
# Verify original tasks are kept for episodes 0 and 1
assert modified_dataset.meta.episodes[0]["tasks"][0] == original_task_ep0
assert modified_dataset.meta.episodes[1]["tasks"][0] == original_task_ep1
# Verify new tasks for overridden episodes
assert modified_dataset.meta.episodes[2]["tasks"][0] == "New Task A"
assert modified_dataset.meta.episodes[3]["tasks"][0] == "New Task B"
assert modified_dataset.meta.episodes[4]["tasks"][0] == "New Task A"
def test_convert_image_to_video_dataset(tmp_path):
"""Test converting lerobot/pusht_image dataset to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset