fix(edit-dataset): translate parent root to exact dataset path for LeRobotDataset v2 compatibility

This commit is contained in:
Khalil Meftah
2026-03-02 14:44:52 +01:00
parent 095856b06a
commit 35f66db227
+15 -7
View File
@@ -208,11 +208,17 @@ class InfoConfig(OperationConfig):
class EditDatasetConfig:
repo_id: str
operation: OperationConfig
# Parent cache directory. Each dataset lives at root/{repo_id}. If None, defaults to HF_LEROBOT_HOME.
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
def _resolve_root(root: str | None, repo_id: str) -> Path | None:
"""Translate a parent cache directory into the exact dataset path expected by LeRobotDataset."""
return Path(root) / repo_id if root else None
def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]:
if new_repo_id:
output_repo_id = new_repo_id
@@ -239,7 +245,7 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
if not cfg.operation.episode_indices:
raise ValueError("episode_indices must be specified for delete_episodes operation")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
)
@@ -272,7 +278,7 @@ def handle_split(cfg: EditDatasetConfig) -> None:
"splits dict must be specified with split names as keys and fractions/episode lists as values"
)
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
split_datasets = split_dataset(dataset, splits=cfg.operation.splits)
@@ -299,7 +305,9 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
raise ValueError("repo_id must be specified as the output repository for merged dataset")
logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids]
datasets = [
LeRobotDataset(repo_id, root=_resolve_root(cfg.root, repo_id)) for repo_id in cfg.operation.repo_ids
]
output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id
@@ -327,7 +335,7 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
if not cfg.operation.feature_names:
raise ValueError("feature_names must be specified for remove_feature operation")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
)
@@ -365,7 +373,7 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
if cfg.new_repo_id is not None:
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
# Convert episode_tasks keys from string to int if needed (CLI passes strings)
@@ -396,7 +404,7 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
# Note: Parser may create any config type with the right fields, so we access fields directly
# instead of checking isinstance()
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
# Determine output directory and repo_id
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
@@ -473,7 +481,7 @@ def handle_info(cfg: EditDatasetConfig):
if not isinstance(cfg.operation, InfoConfig):
raise ValueError("Operation config must be InfoConfig")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")