mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
Feat(dataset_tools.py) Add modify tasks tool (#2875)
* feat(datasets): add modify_tasks function for in-place task editing Add a new utility function to modify tasks in LeRobotDataset in-place. This allows users to: - Set a single task for all episodes - Set specific tasks for individual episodes - Combine a default task with per-episode overrides * feat(edit-dataset): add CLI support for modify_tasks operation Integrate the modify_tasks function into lerobot_edit_dataset CLI. Users can now modify dataset tasks via command line: Supports setting a default task, per-episode tasks, or both combined. * test(datasets): add tests for modify_tasks function Add comprehensive test coverage for the modify_tasks utility: - Single task for all episodes - Episode-specific task assignment - Default task with per-episode overrides - Error handling for missing/invalid arguments - Verification of task_index correctness - In-place modification behavior - Metadata preservation * respond to copilot review
This commit is contained in:
@@ -1396,6 +1396,132 @@ BYTES_PER_KIB = 1024
|
|||||||
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
|
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
|
||||||
|
|
||||||
|
|
||||||
|
def modify_tasks(
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
new_task: str | None = None,
|
||||||
|
episode_tasks: dict[int, str] | None = None,
|
||||||
|
) -> LeRobotDataset:
|
||||||
|
"""Modify tasks in a LeRobotDataset.
|
||||||
|
|
||||||
|
This function allows you to either:
|
||||||
|
1. Set a single task for the entire dataset (using `new_task`)
|
||||||
|
2. Set specific tasks for specific episodes (using `episode_tasks`)
|
||||||
|
|
||||||
|
You can combine both: `new_task` sets the default, and `episode_tasks` overrides
|
||||||
|
specific episodes.
|
||||||
|
|
||||||
|
The dataset is modified in-place, updating only the task-related files:
|
||||||
|
- meta/tasks.parquet
|
||||||
|
- data/**/*.parquet (task_index column)
|
||||||
|
- meta/episodes/**/*.parquet (tasks column)
|
||||||
|
- meta/info.json (total_tasks)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The source LeRobotDataset to modify.
|
||||||
|
new_task: A single task string to apply to all episodes. If None and episode_tasks
|
||||||
|
is also None, raises an error.
|
||||||
|
episode_tasks: Optional dict mapping episode indices to their task strings.
|
||||||
|
Overrides `new_task` for specific episodes.
|
||||||
|
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Set a single task for all episodes:
|
||||||
|
dataset = modify_tasks(dataset, new_task="Pick up the cube")
|
||||||
|
|
||||||
|
Set different tasks for specific episodes:
|
||||||
|
dataset = modify_tasks(
|
||||||
|
dataset,
|
||||||
|
episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
|
||||||
|
)
|
||||||
|
|
||||||
|
Set a default task with overrides:
|
||||||
|
dataset = modify_tasks(
|
||||||
|
dataset,
|
||||||
|
new_task="Default task",
|
||||||
|
episode_tasks={5: "Special task for episode 5"}
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
if new_task is None and episode_tasks is None:
|
||||||
|
raise ValueError("Must specify at least one of new_task or episode_tasks")
|
||||||
|
|
||||||
|
if episode_tasks is not None:
|
||||||
|
valid_indices = set(range(dataset.meta.total_episodes))
|
||||||
|
invalid = set(episode_tasks.keys()) - valid_indices
|
||||||
|
if invalid:
|
||||||
|
raise ValueError(f"Invalid episode indices: {invalid}")
|
||||||
|
|
||||||
|
# Ensure episodes metadata is loaded
|
||||||
|
if dataset.meta.episodes is None:
|
||||||
|
dataset.meta.episodes = load_episodes(dataset.root)
|
||||||
|
|
||||||
|
# Build the mapping from episode index to task string
|
||||||
|
episode_to_task: dict[int, str] = {}
|
||||||
|
for ep_idx in range(dataset.meta.total_episodes):
|
||||||
|
if episode_tasks and ep_idx in episode_tasks:
|
||||||
|
episode_to_task[ep_idx] = episode_tasks[ep_idx]
|
||||||
|
elif new_task is not None:
|
||||||
|
episode_to_task[ep_idx] = new_task
|
||||||
|
else:
|
||||||
|
# Keep original task if not overridden and no default provided
|
||||||
|
original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
|
||||||
|
if not original_tasks:
|
||||||
|
raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
|
||||||
|
episode_to_task[ep_idx] = original_tasks[0]
|
||||||
|
|
||||||
|
# Collect all unique tasks and create new task mapping
|
||||||
|
unique_tasks = sorted(set(episode_to_task.values()))
|
||||||
|
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
|
||||||
|
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
||||||
|
|
||||||
|
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
||||||
|
logging.info(f"New tasks: {unique_tasks}")
|
||||||
|
|
||||||
|
root = dataset.root
|
||||||
|
|
||||||
|
# Update data files - modify task_index column
|
||||||
|
logging.info("Updating data files...")
|
||||||
|
data_dir = root / DATA_DIR
|
||||||
|
|
||||||
|
for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"):
|
||||||
|
df = pd.read_parquet(parquet_path)
|
||||||
|
|
||||||
|
# Build a mapping from episode_index to new task_index for rows in this file
|
||||||
|
episode_indices_in_file = df["episode_index"].unique()
|
||||||
|
ep_to_new_task_idx = {
|
||||||
|
ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update task_index column
|
||||||
|
df["task_index"] = df["episode_index"].map(ep_to_new_task_idx)
|
||||||
|
df.to_parquet(parquet_path, index=False)
|
||||||
|
|
||||||
|
# Update episodes metadata - modify tasks column
|
||||||
|
logging.info("Updating episodes metadata...")
|
||||||
|
episodes_dir = root / "meta" / "episodes"
|
||||||
|
|
||||||
|
for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"):
|
||||||
|
df = pd.read_parquet(parquet_path)
|
||||||
|
|
||||||
|
# Update tasks column
|
||||||
|
df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]])
|
||||||
|
df.to_parquet(parquet_path, index=False)
|
||||||
|
|
||||||
|
# Write new tasks.parquet
|
||||||
|
write_tasks(new_task_df, root)
|
||||||
|
|
||||||
|
# Update info.json
|
||||||
|
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
||||||
|
write_info(dataset.meta.info, root)
|
||||||
|
|
||||||
|
# Reload metadata to reflect changes
|
||||||
|
dataset.meta.tasks = new_task_df
|
||||||
|
dataset.meta.episodes = load_episodes(root)
|
||||||
|
|
||||||
|
logging.info(f"Tasks: {unique_tasks}")
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def convert_image_to_video_dataset(
|
def convert_image_to_video_dataset(
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
output_dir: Path,
|
output_dir: Path,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
Edit LeRobot datasets using various transformation tools.
|
Edit LeRobot datasets using various transformation tools.
|
||||||
|
|
||||||
This script allows you to delete episodes, split datasets, merge datasets,
|
This script allows you to delete episodes, split datasets, merge datasets,
|
||||||
remove features, and convert image datasets to video format.
|
remove features, modify tasks, and convert image datasets to video format.
|
||||||
When new_repo_id is specified, creates a new dataset.
|
When new_repo_id is specified, creates a new dataset.
|
||||||
|
|
||||||
Usage Examples:
|
Usage Examples:
|
||||||
@@ -66,6 +66,25 @@ Remove camera feature:
|
|||||||
--operation.type remove_feature \
|
--operation.type remove_feature \
|
||||||
--operation.feature_names "['observation.images.top']"
|
--operation.feature_names "['observation.images.top']"
|
||||||
|
|
||||||
|
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--operation.type modify_tasks \
|
||||||
|
--operation.new_task "Pick up the cube and place it"
|
||||||
|
|
||||||
|
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--operation.type modify_tasks \
|
||||||
|
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
|
||||||
|
|
||||||
|
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--operation.type modify_tasks \
|
||||||
|
--operation.new_task "Default task" \
|
||||||
|
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||||
|
|
||||||
Convert image dataset to video format and save locally:
|
Convert image dataset to video format and save locally:
|
||||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
--repo_id lerobot/pusht_image \
|
--repo_id lerobot/pusht_image \
|
||||||
@@ -100,6 +119,7 @@ from lerobot.datasets.dataset_tools import (
|
|||||||
convert_image_to_video_dataset,
|
convert_image_to_video_dataset,
|
||||||
delete_episodes,
|
delete_episodes,
|
||||||
merge_datasets,
|
merge_datasets,
|
||||||
|
modify_tasks,
|
||||||
remove_feature,
|
remove_feature,
|
||||||
split_dataset,
|
split_dataset,
|
||||||
)
|
)
|
||||||
@@ -132,6 +152,13 @@ class RemoveFeatureConfig:
|
|||||||
feature_names: list[str] | None = None
|
feature_names: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModifyTasksConfig:
|
||||||
|
type: str = "modify_tasks"
|
||||||
|
new_task: str | None = None
|
||||||
|
episode_tasks: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConvertImageToVideoConfig:
|
class ConvertImageToVideoConfig:
|
||||||
type: str = "convert_image_to_video"
|
type: str = "convert_image_to_video"
|
||||||
@@ -151,7 +178,12 @@ class ConvertImageToVideoConfig:
|
|||||||
class EditDatasetConfig:
|
class EditDatasetConfig:
|
||||||
repo_id: str
|
repo_id: str
|
||||||
operation: (
|
operation: (
|
||||||
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig
|
DeleteEpisodesConfig
|
||||||
|
| SplitConfig
|
||||||
|
| MergeConfig
|
||||||
|
| RemoveFeatureConfig
|
||||||
|
| ModifyTasksConfig
|
||||||
|
| ConvertImageToVideoConfig
|
||||||
)
|
)
|
||||||
root: str | None = None
|
root: str | None = None
|
||||||
new_repo_id: str | None = None
|
new_repo_id: str | None = None
|
||||||
@@ -296,6 +328,48 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
|||||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
|
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||||
|
if not isinstance(cfg.operation, ModifyTasksConfig):
|
||||||
|
raise ValueError("Operation config must be ModifyTasksConfig")
|
||||||
|
|
||||||
|
new_task = cfg.operation.new_task
|
||||||
|
episode_tasks_raw = cfg.operation.episode_tasks
|
||||||
|
|
||||||
|
if new_task is None and episode_tasks_raw is None:
|
||||||
|
raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")
|
||||||
|
|
||||||
|
# Warn about in-place modification behavior
|
||||||
|
if cfg.new_repo_id is not None:
|
||||||
|
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
|
||||||
|
|
||||||
|
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||||
|
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
|
||||||
|
|
||||||
|
# Convert episode_tasks keys from string to int if needed (CLI passes strings)
|
||||||
|
episode_tasks: dict[int, str] | None = None
|
||||||
|
if episode_tasks_raw is not None:
|
||||||
|
episode_tasks = {int(k): v for k, v in episode_tasks_raw.items()}
|
||||||
|
|
||||||
|
logging.info(f"Modifying tasks in {cfg.repo_id}")
|
||||||
|
if new_task:
|
||||||
|
logging.info(f" Default task: '{new_task}'")
|
||||||
|
if episode_tasks:
|
||||||
|
logging.info(f" Episode-specific tasks: {episode_tasks}")
|
||||||
|
|
||||||
|
modified_dataset = modify_tasks(
|
||||||
|
dataset,
|
||||||
|
new_task=new_task,
|
||||||
|
episode_tasks=episode_tasks,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(f"Dataset modified at {dataset.root}")
|
||||||
|
logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}")
|
||||||
|
|
||||||
|
if cfg.push_to_hub:
|
||||||
|
logging.info(f"Pushing to hub as {cfg.repo_id}")
|
||||||
|
modified_dataset.push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||||
# instead of checking isinstance()
|
# instead of checking isinstance()
|
||||||
@@ -371,12 +445,14 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
|||||||
handle_merge(cfg)
|
handle_merge(cfg)
|
||||||
elif operation_type == "remove_feature":
|
elif operation_type == "remove_feature":
|
||||||
handle_remove_feature(cfg)
|
handle_remove_feature(cfg)
|
||||||
|
elif operation_type == "modify_tasks":
|
||||||
|
handle_modify_tasks(cfg)
|
||||||
elif operation_type == "convert_image_to_video":
|
elif operation_type == "convert_image_to_video":
|
||||||
handle_convert_image_to_video(cfg)
|
handle_convert_image_to_video(cfg)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown operation type: {operation_type}\n"
|
f"Unknown operation type: {operation_type}\n"
|
||||||
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
|
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import (
|
|||||||
delete_episodes,
|
delete_episodes,
|
||||||
merge_datasets,
|
merge_datasets,
|
||||||
modify_features,
|
modify_features,
|
||||||
|
modify_tasks,
|
||||||
remove_feature,
|
remove_feature,
|
||||||
split_dataset,
|
split_dataset,
|
||||||
)
|
)
|
||||||
@@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
|||||||
assert "reward" in modified_dataset.meta.features
|
assert "reward" in modified_dataset.meta.features
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_single_task_for_all(sample_dataset):
|
||||||
|
"""Test setting a single task for all episodes."""
|
||||||
|
new_task = "Pick up the cube and place it"
|
||||||
|
|
||||||
|
modified_dataset = modify_tasks(sample_dataset, new_task=new_task)
|
||||||
|
|
||||||
|
# Verify all episodes have the new task
|
||||||
|
assert len(modified_dataset.meta.tasks) == 1
|
||||||
|
assert new_task in modified_dataset.meta.tasks.index
|
||||||
|
|
||||||
|
# Verify task_index is 0 for all frames (only one task)
|
||||||
|
for i in range(len(modified_dataset)):
|
||||||
|
item = modified_dataset[i]
|
||||||
|
assert item["task_index"].item() == 0
|
||||||
|
assert item["task"] == new_task
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_episode_specific(sample_dataset):
|
||||||
|
"""Test setting different tasks for specific episodes."""
|
||||||
|
episode_tasks = {
|
||||||
|
0: "Task A",
|
||||||
|
1: "Task B",
|
||||||
|
2: "Task A",
|
||||||
|
3: "Task C",
|
||||||
|
4: "Task B",
|
||||||
|
}
|
||||||
|
|
||||||
|
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
|
||||||
|
|
||||||
|
# Verify correct number of unique tasks
|
||||||
|
unique_tasks = set(episode_tasks.values())
|
||||||
|
assert len(modified_dataset.meta.tasks) == len(unique_tasks)
|
||||||
|
|
||||||
|
# Verify each episode has the correct task
|
||||||
|
for ep_idx, expected_task in episode_tasks.items():
|
||||||
|
ep_data = modified_dataset.meta.episodes[ep_idx]
|
||||||
|
assert ep_data["tasks"][0] == expected_task
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_default_with_overrides(sample_dataset):
|
||||||
|
"""Test setting a default task with specific overrides."""
|
||||||
|
default_task = "Default task"
|
||||||
|
override_task = "Special task"
|
||||||
|
episode_tasks = {2: override_task, 4: override_task}
|
||||||
|
|
||||||
|
modified_dataset = modify_tasks(
|
||||||
|
sample_dataset,
|
||||||
|
new_task=default_task,
|
||||||
|
episode_tasks=episode_tasks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify correct number of unique tasks
|
||||||
|
assert len(modified_dataset.meta.tasks) == 2
|
||||||
|
assert default_task in modified_dataset.meta.tasks.index
|
||||||
|
assert override_task in modified_dataset.meta.tasks.index
|
||||||
|
|
||||||
|
# Verify episodes have correct tasks
|
||||||
|
for ep_idx in range(5):
|
||||||
|
ep_data = modified_dataset.meta.episodes[ep_idx]
|
||||||
|
if ep_idx in episode_tasks:
|
||||||
|
assert ep_data["tasks"][0] == override_task
|
||||||
|
else:
|
||||||
|
assert ep_data["tasks"][0] == default_task
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_no_task_specified(sample_dataset):
|
||||||
|
"""Test error when no task is specified."""
|
||||||
|
with pytest.raises(ValueError, match="Must specify at least one of new_task or episode_tasks"):
|
||||||
|
modify_tasks(sample_dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_invalid_episode_indices(sample_dataset):
|
||||||
|
"""Test error with invalid episode indices."""
|
||||||
|
with pytest.raises(ValueError, match="Invalid episode indices"):
|
||||||
|
modify_tasks(sample_dataset, episode_tasks={10: "Task", 20: "Task"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_updates_info_json(sample_dataset):
|
||||||
|
"""Test that total_tasks is updated in info.json."""
|
||||||
|
episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"}
|
||||||
|
|
||||||
|
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
|
||||||
|
|
||||||
|
# Verify total_tasks is updated
|
||||||
|
assert modified_dataset.meta.total_tasks == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_preserves_other_metadata(sample_dataset):
|
||||||
|
"""Test that modifying tasks preserves other metadata."""
|
||||||
|
original_frames = sample_dataset.meta.total_frames
|
||||||
|
original_episodes = sample_dataset.meta.total_episodes
|
||||||
|
original_fps = sample_dataset.meta.fps
|
||||||
|
|
||||||
|
modified_dataset = modify_tasks(sample_dataset, new_task="New task")
|
||||||
|
|
||||||
|
# Verify other metadata is preserved
|
||||||
|
assert modified_dataset.meta.total_frames == original_frames
|
||||||
|
assert modified_dataset.meta.total_episodes == original_episodes
|
||||||
|
assert modified_dataset.meta.fps == original_fps
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_task_index_correct(sample_dataset):
|
||||||
|
"""Test that task_index values are correct in data files."""
|
||||||
|
# Create tasks that will have predictable indices (sorted alphabetically)
|
||||||
|
episode_tasks = {
|
||||||
|
0: "Alpha task", # Will be index 0
|
||||||
|
1: "Beta task", # Will be index 1
|
||||||
|
2: "Alpha task", # Will be index 0
|
||||||
|
3: "Gamma task", # Will be index 2
|
||||||
|
4: "Beta task", # Will be index 1
|
||||||
|
}
|
||||||
|
|
||||||
|
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
|
||||||
|
|
||||||
|
# Verify task indices are correct
|
||||||
|
task_to_expected_idx = {
|
||||||
|
"Alpha task": 0,
|
||||||
|
"Beta task": 1,
|
||||||
|
"Gamma task": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(len(modified_dataset)):
|
||||||
|
item = modified_dataset[i]
|
||||||
|
ep_idx = item["episode_index"].item()
|
||||||
|
expected_task = episode_tasks[ep_idx]
|
||||||
|
expected_idx = task_to_expected_idx[expected_task]
|
||||||
|
assert item["task_index"].item() == expected_idx
|
||||||
|
assert item["task"] == expected_task
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_in_place(sample_dataset):
|
||||||
|
"""Test that modify_tasks modifies the dataset in-place."""
|
||||||
|
original_root = sample_dataset.root
|
||||||
|
|
||||||
|
modified_dataset = modify_tasks(sample_dataset, new_task="New task")
|
||||||
|
|
||||||
|
# Verify same instance is returned and root is unchanged
|
||||||
|
assert modified_dataset is sample_dataset
|
||||||
|
assert modified_dataset.root == original_root
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset):
|
||||||
|
"""Test that original tasks are kept when using episode_tasks without new_task."""
|
||||||
|
from lerobot.datasets.utils import load_episodes
|
||||||
|
|
||||||
|
# Ensure episodes metadata is loaded
|
||||||
|
if sample_dataset.meta.episodes is None:
|
||||||
|
sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root)
|
||||||
|
|
||||||
|
# Get original tasks for episodes not being overridden
|
||||||
|
original_task_ep0 = sample_dataset.meta.episodes[0]["tasks"][0]
|
||||||
|
original_task_ep1 = sample_dataset.meta.episodes[1]["tasks"][0]
|
||||||
|
|
||||||
|
# Only override episodes 2, 3, 4
|
||||||
|
episode_tasks = {2: "New Task A", 3: "New Task B", 4: "New Task A"}
|
||||||
|
|
||||||
|
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
|
||||||
|
|
||||||
|
# Verify original tasks are kept for episodes 0 and 1
|
||||||
|
assert modified_dataset.meta.episodes[0]["tasks"][0] == original_task_ep0
|
||||||
|
assert modified_dataset.meta.episodes[1]["tasks"][0] == original_task_ep1
|
||||||
|
|
||||||
|
# Verify new tasks for overridden episodes
|
||||||
|
assert modified_dataset.meta.episodes[2]["tasks"][0] == "New Task A"
|
||||||
|
assert modified_dataset.meta.episodes[3]["tasks"][0] == "New Task B"
|
||||||
|
assert modified_dataset.meta.episodes[4]["tasks"][0] == "New Task A"
|
||||||
|
|
||||||
|
|
||||||
def test_convert_image_to_video_dataset(tmp_path):
|
def test_convert_image_to_video_dataset(tmp_path):
|
||||||
"""Test converting lerobot/pusht_image dataset to video format."""
|
"""Test converting lerobot/pusht_image dataset to video format."""
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|||||||
Reference in New Issue
Block a user