diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index e2928e2a6..123d455c6 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -1396,6 +1396,132 @@ BYTES_PER_KIB = 1024 BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB +def modify_tasks( + dataset: LeRobotDataset, + new_task: str | None = None, + episode_tasks: dict[int, str] | None = None, +) -> LeRobotDataset: + """Modify tasks in a LeRobotDataset. + + This function allows you to either: + 1. Set a single task for the entire dataset (using `new_task`) + 2. Set specific tasks for specific episodes (using `episode_tasks`) + + You can combine both: `new_task` sets the default, and `episode_tasks` overrides + specific episodes. + + The dataset is modified in-place, updating only the task-related files: + - meta/tasks.parquet + - data/**/*.parquet (task_index column) + - meta/episodes/**/*.parquet (tasks column) + - meta/info.json (total_tasks) + + Args: + dataset: The source LeRobotDataset to modify. + new_task: A single task string to apply to all episodes. If None and episode_tasks + is also None, raises an error. + episode_tasks: Optional dict mapping episode indices to their task strings. + Overrides `new_task` for specific episodes. + + + Examples: + Set a single task for all episodes: + dataset = modify_tasks(dataset, new_task="Pick up the cube") + + Set different tasks for specific episodes: + dataset = modify_tasks( + dataset, + episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"} + ) + + Set a default task with overrides: + dataset = modify_tasks( + dataset, + new_task="Default task", + episode_tasks={5: "Special task for episode 5"} + ) + """ + if new_task is None and episode_tasks is None: + raise ValueError("Must specify at least one of new_task or episode_tasks") + + if episode_tasks is not None: + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = set(episode_tasks.keys()) - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + # Ensure episodes metadata is loaded + if dataset.meta.episodes is None: + dataset.meta.episodes = load_episodes(dataset.root) + + # Build the mapping from episode index to task string + episode_to_task: dict[int, str] = {} + for ep_idx in range(dataset.meta.total_episodes): + if episode_tasks and ep_idx in episode_tasks: + episode_to_task[ep_idx] = episode_tasks[ep_idx] + elif new_task is not None: + episode_to_task[ep_idx] = new_task + else: + # Keep original task if not overridden and no default provided + original_tasks = dataset.meta.episodes[ep_idx]["tasks"] + if not original_tasks: + raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided") + episode_to_task[ep_idx] = original_tasks[0] + + # Collect all unique tasks and create new task mapping + unique_tasks = sorted(set(episode_to_task.values())) + new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks) + task_to_index = {task: idx for idx, task in enumerate(unique_tasks)} + + logging.info(f"Modifying tasks in {dataset.repo_id}") + logging.info(f"New tasks: {unique_tasks}") + + root = dataset.root + + # Update data files - modify task_index column + logging.info("Updating data files...") + data_dir = root / DATA_DIR + + for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"): + df = pd.read_parquet(parquet_path) + + # Build a mapping from episode_index to new task_index for rows in this file + episode_indices_in_file = df["episode_index"].unique() + ep_to_new_task_idx = { + ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file + } + + # Update task_index column + df["task_index"] = df["episode_index"].map(ep_to_new_task_idx) + df.to_parquet(parquet_path, index=False) + + # Update episodes metadata - modify tasks column + logging.info("Updating episodes metadata...") + episodes_dir = root / "meta" / "episodes" + + for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"): + df = pd.read_parquet(parquet_path) + + # Update tasks column + df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]]) + df.to_parquet(parquet_path, index=False) + + # Write new tasks.parquet + write_tasks(new_task_df, root) + + # Update info.json + dataset.meta.info["total_tasks"] = len(unique_tasks) + write_info(dataset.meta.info, root) + + # Reload metadata to reflect changes + dataset.meta.tasks = new_task_df + dataset.meta.episodes = load_episodes(root) + + logging.info(f"Tasks: {unique_tasks}") + + return dataset + + def convert_image_to_video_dataset( dataset: LeRobotDataset, output_dir: Path, diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 4ba6ce44f..2ca9c520d 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -18,7 +18,7 @@ Edit LeRobot datasets using various transformation tools. This script allows you to delete episodes, split datasets, merge datasets, -remove features, and convert image datasets to video format. +remove features, modify tasks, and convert image datasets to video format. When new_repo_id is specified, creates a new dataset. Usage Examples: @@ -66,6 +66,25 @@ Remove camera feature: --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" +Modify tasks - set a single task for all episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.new_task "Pick up the cube and place it" + +Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}' + +Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.new_task "Default task" \ + --operation.episode_tasks '{"5": "Special task for episode 5"}' + Convert image dataset to video format and save locally: python -m lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ @@ -100,6 +119,7 @@ from lerobot.datasets.dataset_tools import ( convert_image_to_video_dataset, delete_episodes, merge_datasets, + modify_tasks, remove_feature, split_dataset, ) @@ -132,6 +152,13 @@ class RemoveFeatureConfig: feature_names: list[str] | None = None +@dataclass +class ModifyTasksConfig: + type: str = "modify_tasks" + new_task: str | None = None + episode_tasks: dict[str, str] | None = None + + @dataclass class ConvertImageToVideoConfig: type: str = "convert_image_to_video" @@ -151,7 +178,12 @@ class ConvertImageToVideoConfig: class EditDatasetConfig: repo_id: str operation: ( - DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig + DeleteEpisodesConfig + | SplitConfig + | MergeConfig + | RemoveFeatureConfig + | ModifyTasksConfig + | ConvertImageToVideoConfig ) root: str | None = None new_repo_id: str | None = None @@ -296,6 +328,48 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None: LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() +def handle_modify_tasks(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, ModifyTasksConfig): + raise ValueError("Operation config must be ModifyTasksConfig") + + new_task = cfg.operation.new_task + episode_tasks_raw = cfg.operation.episode_tasks + + if new_task is None and episode_tasks_raw is None: + raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation") + + # Warn about in-place modification behavior + if cfg.new_repo_id is not None: + logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.") + + # Convert episode_tasks keys from string to int if needed (CLI passes strings) + episode_tasks: dict[int, str] | None = None + if episode_tasks_raw is not None: + episode_tasks = {int(k): v for k, v in episode_tasks_raw.items()} + + logging.info(f"Modifying tasks in {cfg.repo_id}") + if new_task: + logging.info(f" Default task: '{new_task}'") + if episode_tasks: + logging.info(f" Episode-specific tasks: {episode_tasks}") + + modified_dataset = modify_tasks( + dataset, + new_task=new_task, + episode_tasks=episode_tasks, + ) + + logging.info(f"Dataset modified at {dataset.root}") + logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {cfg.repo_id}") + modified_dataset.push_to_hub() + + def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: # Note: Parser may create any config type with the right fields, so we access fields directly # instead of checking isinstance() @@ -371,12 +445,14 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: handle_merge(cfg) elif operation_type == "remove_feature": handle_remove_feature(cfg) + elif operation_type == "modify_tasks": + handle_modify_tasks(cfg) elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) else: raise ValueError( f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video" + f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video" ) diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 35a369de9..1de199630 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import ( delete_episodes, merge_datasets, modify_features, + modify_tasks, remove_feature, split_dataset, ) @@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): assert "reward" in modified_dataset.meta.features +def test_modify_tasks_single_task_for_all(sample_dataset): + """Test setting a single task for all episodes.""" + new_task = "Pick up the cube and place it" + + modified_dataset = modify_tasks(sample_dataset, new_task=new_task) + + # Verify all episodes have the new task + assert len(modified_dataset.meta.tasks) == 1 + assert new_task in modified_dataset.meta.tasks.index + + # Verify task_index is 0 for all frames (only one task) + for i in range(len(modified_dataset)): + item = modified_dataset[i] + assert item["task_index"].item() == 0 + assert item["task"] == new_task + + +def test_modify_tasks_episode_specific(sample_dataset): + """Test setting different tasks for specific episodes.""" + episode_tasks = { + 0: "Task A", + 1: "Task B", + 2: "Task A", + 3: "Task C", + 4: "Task B", + } + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify correct number of unique tasks + unique_tasks = set(episode_tasks.values()) + assert len(modified_dataset.meta.tasks) == len(unique_tasks) + + # Verify each episode has the correct task + for ep_idx, expected_task in episode_tasks.items(): + ep_data = modified_dataset.meta.episodes[ep_idx] + assert ep_data["tasks"][0] == expected_task + + +def test_modify_tasks_default_with_overrides(sample_dataset): + """Test setting a default task with specific overrides.""" + default_task = "Default task" + override_task = "Special task" + episode_tasks = {2: override_task, 4: override_task} + + modified_dataset = modify_tasks( + sample_dataset, + new_task=default_task, + episode_tasks=episode_tasks, + ) + + # Verify correct number of unique tasks + assert len(modified_dataset.meta.tasks) == 2 + assert default_task in modified_dataset.meta.tasks.index + assert override_task in modified_dataset.meta.tasks.index + + # Verify episodes have correct tasks + for ep_idx in range(5): + ep_data = modified_dataset.meta.episodes[ep_idx] + if ep_idx in episode_tasks: + assert ep_data["tasks"][0] == override_task + else: + assert ep_data["tasks"][0] == default_task + + +def test_modify_tasks_no_task_specified(sample_dataset): + """Test error when no task is specified.""" + with pytest.raises(ValueError, match="Must specify at least one of new_task or episode_tasks"): + modify_tasks(sample_dataset) + + +def test_modify_tasks_invalid_episode_indices(sample_dataset): + """Test error with invalid episode indices.""" + with pytest.raises(ValueError, match="Invalid episode indices"): + modify_tasks(sample_dataset, episode_tasks={10: "Task", 20: "Task"}) + + +def test_modify_tasks_updates_info_json(sample_dataset): + """Test that total_tasks is updated in info.json.""" + episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"} + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify total_tasks is updated + assert modified_dataset.meta.total_tasks == 3 + + +def test_modify_tasks_preserves_other_metadata(sample_dataset): + """Test that modifying tasks preserves other metadata.""" + original_frames = sample_dataset.meta.total_frames + original_episodes = sample_dataset.meta.total_episodes + original_fps = sample_dataset.meta.fps + + modified_dataset = modify_tasks(sample_dataset, new_task="New task") + + # Verify other metadata is preserved + assert modified_dataset.meta.total_frames == original_frames + assert modified_dataset.meta.total_episodes == original_episodes + assert modified_dataset.meta.fps == original_fps + + +def test_modify_tasks_task_index_correct(sample_dataset): + """Test that task_index values are correct in data files.""" + # Create tasks that will have predictable indices (sorted alphabetically) + episode_tasks = { + 0: "Alpha task", # Will be index 0 + 1: "Beta task", # Will be index 1 + 2: "Alpha task", # Will be index 0 + 3: "Gamma task", # Will be index 2 + 4: "Beta task", # Will be index 1 + } + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify task indices are correct + task_to_expected_idx = { + "Alpha task": 0, + "Beta task": 1, + "Gamma task": 2, + } + + for i in range(len(modified_dataset)): + item = modified_dataset[i] + ep_idx = item["episode_index"].item() + expected_task = episode_tasks[ep_idx] + expected_idx = task_to_expected_idx[expected_task] + assert item["task_index"].item() == expected_idx + assert item["task"] == expected_task + + +def test_modify_tasks_in_place(sample_dataset): + """Test that modify_tasks modifies the dataset in-place.""" + original_root = sample_dataset.root + + modified_dataset = modify_tasks(sample_dataset, new_task="New task") + + # Verify same instance is returned and root is unchanged + assert modified_dataset is sample_dataset + assert modified_dataset.root == original_root + + +def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset): + """Test that original tasks are kept when using episode_tasks without new_task.""" + from lerobot.datasets.utils import load_episodes + + # Ensure episodes metadata is loaded + if sample_dataset.meta.episodes is None: + sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root) + + # Get original tasks for episodes not being overridden + original_task_ep0 = sample_dataset.meta.episodes[0]["tasks"][0] + original_task_ep1 = sample_dataset.meta.episodes[1]["tasks"][0] + + # Only override episodes 2, 3, 4 + episode_tasks = {2: "New Task A", 3: "New Task B", 4: "New Task A"} + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify original tasks are kept for episodes 0 and 1 + assert modified_dataset.meta.episodes[0]["tasks"][0] == original_task_ep0 + assert modified_dataset.meta.episodes[1]["tasks"][0] == original_task_ep1 + + # Verify new tasks for overridden episodes + assert modified_dataset.meta.episodes[2]["tasks"][0] == "New Task A" + assert modified_dataset.meta.episodes[3]["tasks"][0] == "New Task B" + assert modified_dataset.meta.episodes[4]["tasks"][0] == "New Task A" + + def test_convert_image_to_video_dataset(tmp_path): """Test converting lerobot/pusht_image dataset to video format.""" from lerobot.datasets.lerobot_dataset import LeRobotDataset