Feat(dataset_tools.py) Add modify tasks tool (#2875)

* feat(datasets): add modify_tasks function for in-place task editing

Add a new utility function to modify tasks in LeRobotDataset in-place.
This allows users to:
- Set a single task for all episodes
- Set specific tasks for individual episodes
- Combine a default task with per-episode overrides

* feat(edit-dataset): add CLI support for modify_tasks operation

Integrate the modify_tasks function into lerobot_edit_dataset CLI.
Users can now modify dataset tasks from the command line.
The CLI supports setting a default task, per-episode tasks, or both combined.

* test(datasets): add tests for modify_tasks function

Add comprehensive test coverage for the modify_tasks utility:
- Single task for all episodes
- Episode-specific task assignment
- Default task with per-episode overrides
- Error handling for missing/invalid arguments
- Verification of task_index correctness
- In-place modification behavior
- Metadata preservation

* respond to copilot review
This commit is contained in:
Michel Aractingi
2026-01-30 13:19:42 +01:00
committed by GitHub
parent 04cbf669cf
commit ec04b7ce3a
3 changed files with 374 additions and 3 deletions
+126
View File
@@ -1396,6 +1396,132 @@ BYTES_PER_KIB = 1024
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
def modify_tasks(
    dataset: LeRobotDataset,
    new_task: str | None = None,
    episode_tasks: dict[int, str] | None = None,
) -> LeRobotDataset:
    """Rewrite the task labels of a LeRobotDataset in-place.

    Every episode ends up with exactly one task, resolved in this order:

    1. the episode's entry in `episode_tasks`, when it has one;
    2. otherwise `new_task`, when it is given;
    3. otherwise the first task the episode already carries.

    Only the task-related files under the dataset root are rewritten:
        - meta/tasks.parquet
        - data/**/*.parquet (task_index column)
        - meta/episodes/**/*.parquet (tasks column)
        - meta/info.json (total_tasks)

    Task indices are assigned by sorting the distinct task strings.

    Args:
        dataset: The LeRobotDataset to modify.
        new_task: A single task string to apply to all episodes. If None and
            `episode_tasks` is also None, an error is raised.
        episode_tasks: Optional dict mapping episode indices to their task strings.
            Takes precedence over `new_task` for the listed episodes.

    Returns:
        The same `dataset` object, with `meta.tasks`, `meta.episodes` and
        `meta.info["total_tasks"]` refreshed.

    Raises:
        ValueError: If neither argument is given, if `episode_tasks` contains an
            index outside `range(dataset.meta.total_episodes)`, or if an episode
            has no existing task and nothing was supplied for it.

    Examples:
        Set a single task for all episodes:
            dataset = modify_tasks(dataset, new_task="Pick up the cube")

        Set different tasks for specific episodes:
            dataset = modify_tasks(
                dataset,
                episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
            )

        Set a default task with overrides:
            dataset = modify_tasks(
                dataset,
                new_task="Default task",
                episode_tasks={5: "Special task for episode 5"}
            )
    """
    if new_task is None and episode_tasks is None:
        raise ValueError("Must specify at least one of new_task or episode_tasks")

    if episode_tasks is not None:
        known_episodes = set(range(dataset.meta.total_episodes))
        invalid = set(episode_tasks.keys()) - known_episodes
        if invalid:
            raise ValueError(f"Invalid episode indices: {invalid}")

    # The fallback branch below reads existing tasks, so episodes metadata must be loaded.
    if dataset.meta.episodes is None:
        dataset.meta.episodes = load_episodes(dataset.root)

    overrides = episode_tasks or {}

    def _resolve_task(ep_idx: int) -> str:
        """Pick the task for one episode: override, then default, then existing."""
        if ep_idx in overrides:
            return overrides[ep_idx]
        if new_task is not None:
            return new_task
        # Keep original task if not overridden and no default provided
        original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
        if not original_tasks:
            raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
        return original_tasks[0]

    episode_to_task: dict[int, str] = {
        ep_idx: _resolve_task(ep_idx) for ep_idx in range(dataset.meta.total_episodes)
    }

    # Sorted distinct tasks define the new task table and its indices.
    unique_tasks = sorted(set(episode_to_task.values()))
    task_indices = list(range(len(unique_tasks)))
    new_task_df = pd.DataFrame({"task_index": task_indices}, index=unique_tasks)
    task_to_index = dict(zip(unique_tasks, task_indices))

    logging.info(f"Modifying tasks in {dataset.repo_id}")
    logging.info(f"New tasks: {unique_tasks}")

    root = dataset.root

    # Pass 1: frame-level data files get a fresh task_index column.
    logging.info("Updating data files...")
    data_files = sorted((root / DATA_DIR).rglob("*.parquet"))
    for data_file in tqdm(data_files, desc="Updating data"):
        frames = pd.read_parquet(data_file)
        # Only the episodes present in this file need an index lookup.
        index_by_episode = {
            ep_idx: task_to_index[episode_to_task[ep_idx]]
            for ep_idx in frames["episode_index"].unique()
        }
        frames["task_index"] = frames["episode_index"].map(index_by_episode)
        frames.to_parquet(data_file, index=False)

    def _as_task_list(ep_idx):
        """Wrap the episode's task in a single-element list for the tasks column."""
        return [episode_to_task[ep_idx]]

    # Pass 2: episode-level metadata files get a fresh tasks column.
    logging.info("Updating episodes metadata...")
    episode_files = sorted((root / "meta" / "episodes").rglob("*.parquet"))
    for episode_file in tqdm(episode_files, desc="Updating episodes"):
        episodes_df = pd.read_parquet(episode_file)
        episodes_df["tasks"] = episodes_df["episode_index"].apply(_as_task_list)
        episodes_df.to_parquet(episode_file, index=False)

    # Persist the task table, then the updated task count.
    write_tasks(new_task_df, root)
    dataset.meta.info["total_tasks"] = len(unique_tasks)
    write_info(dataset.meta.info, root)

    # Bring the in-memory metadata in line with what is now on disk.
    dataset.meta.tasks = new_task_df
    dataset.meta.episodes = load_episodes(root)

    logging.info(f"Tasks: {unique_tasks}")
    return dataset
def convert_image_to_video_dataset( def convert_image_to_video_dataset(
dataset: LeRobotDataset, dataset: LeRobotDataset,
output_dir: Path, output_dir: Path,
+79 -3
View File
@@ -18,7 +18,7 @@
Edit LeRobot datasets using various transformation tools. Edit LeRobot datasets using various transformation tools.
This script allows you to delete episodes, split datasets, merge datasets, This script allows you to delete episodes, split datasets, merge datasets,
remove features, and convert image datasets to video format. remove features, modify tasks, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset. When new_repo_id is specified, creates a new dataset.
Usage Examples: Usage Examples:
@@ -66,6 +66,25 @@ Remove camera feature:
--operation.type remove_feature \ --operation.type remove_feature \
--operation.feature_names "['observation.images.top']" --operation.feature_names "['observation.images.top']"
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.new_task "Pick up the cube and place it"
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type modify_tasks \
--operation.new_task "Default task" \
--operation.episode_tasks '{"5": "Special task for episode 5"}'
Convert image dataset to video format and save locally: Convert image dataset to video format and save locally:
python -m lerobot.scripts.lerobot_edit_dataset \ python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \ --repo_id lerobot/pusht_image \
@@ -100,6 +119,7 @@ from lerobot.datasets.dataset_tools import (
convert_image_to_video_dataset, convert_image_to_video_dataset,
delete_episodes, delete_episodes,
merge_datasets, merge_datasets,
modify_tasks,
remove_feature, remove_feature,
split_dataset, split_dataset,
) )
@@ -132,6 +152,13 @@ class RemoveFeatureConfig:
feature_names: list[str] | None = None feature_names: list[str] | None = None
@dataclass
class ModifyTasksConfig:
    """Options for the `modify_tasks` edit operation.

    At least one of `new_task` or `episode_tasks` must be set; the handler
    raises ValueError otherwise.
    """

    # Operation name; `edit_dataset` compares the operation type against this string.
    type: str = "modify_tasks"
    # Task string applied to every episode that has no entry in `episode_tasks`.
    new_task: str | None = None
    # Per-episode overrides. Keys are episode indices given as strings (CLI input);
    # the handler converts them to int before calling `modify_tasks`.
    episode_tasks: dict[str, str] | None = None
@dataclass @dataclass
class ConvertImageToVideoConfig: class ConvertImageToVideoConfig:
type: str = "convert_image_to_video" type: str = "convert_image_to_video"
@@ -151,7 +178,12 @@ class ConvertImageToVideoConfig:
class EditDatasetConfig: class EditDatasetConfig:
repo_id: str repo_id: str
operation: ( operation: (
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig DeleteEpisodesConfig
| SplitConfig
| MergeConfig
| RemoveFeatureConfig
| ModifyTasksConfig
| ConvertImageToVideoConfig
) )
root: str | None = None root: str | None = None
new_repo_id: str | None = None new_repo_id: str | None = None
@@ -296,6 +328,48 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
    """Run the `modify_tasks` operation described by `cfg`.

    The dataset identified by `cfg.repo_id` / `cfg.root` is modified in-place;
    `cfg.new_repo_id` is ignored (a warning is logged when it is set). When
    `cfg.push_to_hub` is true the modified dataset is pushed afterwards.

    Args:
        cfg: Edit configuration whose `operation` must be a ModifyTasksConfig.

    Raises:
        ValueError: If the operation config has the wrong type, if neither
            `new_task` nor `episode_tasks` is given, or if an `episode_tasks`
            key cannot be converted to an integer episode index.
    """
    # NOTE(review): handle_convert_image_to_video deliberately avoids isinstance()
    # because the parser may build a different config class with the right fields.
    # Confirm this strict check cannot reject a valid modify_tasks invocation.
    if not isinstance(cfg.operation, ModifyTasksConfig):
        raise ValueError("Operation config must be ModifyTasksConfig")

    new_task = cfg.operation.new_task
    episode_tasks_raw = cfg.operation.episode_tasks

    if new_task is None and episode_tasks_raw is None:
        raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")

    # Convert episode_tasks keys from string to int (CLI passes strings). This is
    # done before the dataset is instantiated so a malformed mapping fails fast
    # and the error names the offending key.
    episode_tasks: dict[int, str] | None = None
    if episode_tasks_raw is not None:
        episode_tasks = {}
        for raw_key, task in episode_tasks_raw.items():
            try:
                ep_idx = int(raw_key)
            except ValueError as err:
                raise ValueError(
                    f"Invalid episode index {raw_key!r} in episode_tasks: keys must be integers"
                ) from err
            episode_tasks[ep_idx] = task

    # Warn about in-place modification behavior
    if cfg.new_repo_id is not None:
        logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
    logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")

    logging.info(f"Modifying tasks in {cfg.repo_id}")
    # `is not None` mirrors the condition modify_tasks uses to apply the default,
    # so an empty-string default is still reported.
    if new_task is not None:
        logging.info(f"  Default task: '{new_task}'")
    if episode_tasks:
        logging.info(f"  Episode-specific tasks: {episode_tasks}")

    modified_dataset = modify_tasks(
        dataset,
        new_task=new_task,
        episode_tasks=episode_tasks,
    )

    logging.info(f"Dataset modified at {dataset.root}")
    logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}")

    if cfg.push_to_hub:
        logging.info(f"Pushing to hub as {cfg.repo_id}")
        modified_dataset.push_to_hub()
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
# Note: Parser may create any config type with the right fields, so we access fields directly # Note: Parser may create any config type with the right fields, so we access fields directly
# instead of checking isinstance() # instead of checking isinstance()
@@ -371,12 +445,14 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_merge(cfg) handle_merge(cfg)
elif operation_type == "remove_feature": elif operation_type == "remove_feature":
handle_remove_feature(cfg) handle_remove_feature(cfg)
elif operation_type == "modify_tasks":
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video": elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg) handle_convert_image_to_video(cfg)
else: else:
raise ValueError( raise ValueError(
f"Unknown operation type: {operation_type}\n" f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video" f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
) )
+169
View File
@@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import (
delete_episodes, delete_episodes,
merge_datasets, merge_datasets,
modify_features, modify_features,
modify_tasks,
remove_feature, remove_feature,
split_dataset, split_dataset,
) )
@@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
assert "reward" in modified_dataset.meta.features assert "reward" in modified_dataset.meta.features
def test_modify_tasks_single_task_for_all(sample_dataset):
    """A lone `new_task` must become the task of every episode and frame."""
    task_text = "Pick up the cube and place it"

    result = modify_tasks(sample_dataset, new_task=task_text)

    # The task table collapses to a single entry: the new task.
    assert len(result.meta.tasks) == 1
    assert task_text in result.meta.tasks.index

    # With only one task, every frame must point at index 0.
    for frame_idx in range(len(result)):
        frame = result[frame_idx]
        assert frame["task_index"].item() == 0
        assert frame["task"] == task_text
def test_modify_tasks_episode_specific(sample_dataset):
    """Per-episode assignments must land on exactly the episodes they name."""
    task_by_episode = {0: "Task A", 1: "Task B", 2: "Task A", 3: "Task C", 4: "Task B"}

    result = modify_tasks(sample_dataset, episode_tasks=task_by_episode)

    # The task table holds one row per distinct task string.
    distinct_tasks = set(task_by_episode.values())
    assert len(result.meta.tasks) == len(distinct_tasks)

    # Each episode's metadata reports the task it was assigned.
    for ep_idx, wanted in task_by_episode.items():
        assert result.meta.episodes[ep_idx]["tasks"][0] == wanted
def test_modify_tasks_default_with_overrides(sample_dataset):
    """`new_task` acts as the fallback while `episode_tasks` wins where present."""
    fallback = "Default task"
    special = "Special task"
    overrides = {2: special, 4: special}

    result = modify_tasks(sample_dataset, new_task=fallback, episode_tasks=overrides)

    # Exactly the two task strings in play appear in the task table.
    task_names = result.meta.tasks.index
    assert len(result.meta.tasks) == 2
    assert fallback in task_names
    assert special in task_names

    # Overridden episodes carry the special task, all others the fallback.
    for ep_idx in range(5):
        wanted = overrides.get(ep_idx, fallback)
        assert result.meta.episodes[ep_idx]["tasks"][0] == wanted
def test_modify_tasks_no_task_specified(sample_dataset):
    """Calling without `new_task` or `episode_tasks` must be rejected."""
    expected_message = "Must specify at least one of new_task or episode_tasks"
    with pytest.raises(ValueError, match=expected_message):
        modify_tasks(sample_dataset)
def test_modify_tasks_invalid_episode_indices(sample_dataset):
    """Episode indices that do not exist in the dataset must be rejected."""
    out_of_range = {10: "Task", 20: "Task"}
    with pytest.raises(ValueError, match="Invalid episode indices"):
        modify_tasks(sample_dataset, episode_tasks=out_of_range)
def test_modify_tasks_updates_info_json(sample_dataset):
    """Test that total_tasks is updated in memory and in meta/info.json on disk."""
    import json

    episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"}
    modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)

    # Verify total_tasks is updated on the in-memory metadata
    assert modified_dataset.meta.total_tasks == 3

    # Verify the value was actually persisted; checking only the in-memory
    # metadata would not catch a missing write of info.json.
    info_path = modified_dataset.root / "meta" / "info.json"
    with open(info_path, encoding="utf-8") as f:
        info_on_disk = json.load(f)
    assert info_on_disk["total_tasks"] == 3
def test_modify_tasks_preserves_other_metadata(sample_dataset):
    """Frame count, episode count and fps must survive a task rewrite."""

    def snapshot(meta):
        """Collect the metadata fields that modify_tasks must leave alone."""
        return (meta.total_frames, meta.total_episodes, meta.fps)

    before = snapshot(sample_dataset.meta)

    result = modify_tasks(sample_dataset, new_task="New task")

    # None of the non-task metadata may have changed.
    assert snapshot(result.meta) == before
def test_modify_tasks_task_index_correct(sample_dataset):
    """Every frame's task_index must match the alphabetical rank of its task."""
    # Task names chosen so their sorted order, and hence their index, is obvious.
    task_by_episode = {
        0: "Alpha task",  # index 0
        1: "Beta task",  # index 1
        2: "Alpha task",  # index 0
        3: "Gamma task",  # index 2
        4: "Beta task",  # index 1
    }
    index_by_task = {"Alpha task": 0, "Beta task": 1, "Gamma task": 2}

    result = modify_tasks(sample_dataset, episode_tasks=task_by_episode)

    # Walk every frame and check both the numeric index and the task string.
    for frame_idx in range(len(result)):
        frame = result[frame_idx]
        wanted_task = task_by_episode[frame["episode_index"].item()]
        assert frame["task_index"].item() == index_by_task[wanted_task]
        assert frame["task"] == wanted_task
def test_modify_tasks_in_place(sample_dataset):
    """modify_tasks must return the dataset it was given, at the same root."""
    root_before = sample_dataset.root

    result = modify_tasks(sample_dataset, new_task="New task")

    # Same object back, and it still lives at the same location.
    assert result is sample_dataset
    assert result.root == root_before
def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset):
    """Without `new_task`, episodes missing from `episode_tasks` keep their task."""
    from lerobot.datasets.utils import load_episodes

    # The original tasks are read from episodes metadata, so load it if needed.
    if sample_dataset.meta.episodes is None:
        sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root)

    # Remember the tasks of the episodes that will not be overridden.
    untouched_episodes = (0, 1)
    tasks_before = {
        ep_idx: sample_dataset.meta.episodes[ep_idx]["tasks"][0] for ep_idx in untouched_episodes
    }

    # Only override episodes 2, 3, 4
    overrides = {2: "New Task A", 3: "New Task B", 4: "New Task A"}
    result = modify_tasks(sample_dataset, episode_tasks=overrides)

    # Episodes 0 and 1 still carry their original task.
    for ep_idx, original_task in tasks_before.items():
        assert result.meta.episodes[ep_idx]["tasks"][0] == original_task

    # Overridden episodes carry the new task.
    for ep_idx, new_task in overrides.items():
        assert result.meta.episodes[ep_idx]["tasks"][0] == new_task
def test_convert_image_to_video_dataset(tmp_path): def test_convert_image_to_video_dataset(tmp_path):
"""Test converting lerobot/pusht_image dataset to video format.""" """Test converting lerobot/pusht_image dataset to video format."""
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset