mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(edit-dataset): translate parent root to exact dataset path LeRobotDataset v2 compatibility
This commit is contained in:
@@ -208,11 +208,17 @@ class InfoConfig(OperationConfig):
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: OperationConfig
|
||||
# Parent cache directory. Each dataset lives at root/{repo_id}. If None, defaults to HF_LEROBOT_HOME.
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
push_to_hub: bool = False
|
||||
|
||||
|
||||
def _resolve_root(root: str | None, repo_id: str) -> Path | None:
|
||||
"""Translate a parent cache directory into the exact dataset path expected by LeRobotDataset."""
|
||||
return Path(root) / repo_id if root else None
|
||||
|
||||
|
||||
def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]:
|
||||
if new_repo_id:
|
||||
output_repo_id = new_repo_id
|
||||
@@ -239,7 +245,7 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
|
||||
if not cfg.operation.episode_indices:
|
||||
raise ValueError("episode_indices must be specified for delete_episodes operation")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
)
|
||||
@@ -272,7 +278,7 @@ def handle_split(cfg: EditDatasetConfig) -> None:
|
||||
"splits dict must be specified with split names as keys and fractions/episode lists as values"
|
||||
)
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
||||
|
||||
logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
|
||||
split_datasets = split_dataset(dataset, splits=cfg.operation.splits)
|
||||
@@ -299,7 +305,9 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
|
||||
raise ValueError("repo_id must be specified as the output repository for merged dataset")
|
||||
|
||||
logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
|
||||
datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids]
|
||||
datasets = [
|
||||
LeRobotDataset(repo_id, root=_resolve_root(cfg.root, repo_id)) for repo_id in cfg.operation.repo_ids
|
||||
]
|
||||
|
||||
output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id
|
||||
|
||||
@@ -327,7 +335,7 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
if not cfg.operation.feature_names:
|
||||
raise ValueError("feature_names must be specified for remove_feature operation")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
)
|
||||
@@ -365,7 +373,7 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
if cfg.new_repo_id is not None:
|
||||
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
||||
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
|
||||
|
||||
# Convert episode_tasks keys from string to int if needed (CLI passes strings)
|
||||
@@ -396,7 +404,7 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
||||
|
||||
# Determine output directory and repo_id
|
||||
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
|
||||
@@ -473,7 +481,7 @@ def handle_info(cfg: EditDatasetConfig):
|
||||
if not isinstance(cfg.operation, InfoConfig):
|
||||
raise ValueError("Operation config must be InfoConfig")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
||||
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
|
||||
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
|
||||
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
|
||||
|
||||
Reference in New Issue
Block a user