diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 0cfb34325..a708d37a3 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -150,11 +150,24 @@ Show dataset information without feature details: --operation.type info \ --operation.show_features false -Recompute dataset statistics: +Recompute dataset statistics (saves to lerobot/pusht_recomputed_stats by default): lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type recompute_stats +Recompute stats and save to a specific new repo_id: + lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_new_stats \ + --operation.type recompute_stats + +Recompute stats in-place (overwrites original dataset stats): + lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht \ + --operation.type recompute_stats \ + --operation.overwrite true + Recompute stats for relative actions and push to hub: lerobot-edit-dataset \ --repo_id lerobot/pusht \ @@ -256,6 +269,7 @@ class RecomputeStatsConfig(OperationConfig): relative_exclude_joints: list[str] | None = None chunk_size: int = 50 num_workers: int = 0 + overwrite: bool = False @OperationConfig.register_subclass("info") @@ -280,16 +294,30 @@ class EditDatasetConfig: push_to_hub: bool = False +def _resolve_io_paths( + repo_id: str, + new_repo_id: str | None, + root: Path | str | None, + new_root: Path | str | None, + default_new_repo_id: str | None = None, +) -> tuple[str, Path, Path]: + """Resolve input/output paths and repo_id for dataset operations. + + Returns (output_repo_id, input_path, output_path) with resolved (symlink-safe) paths. + """ + input_path = (Path(root) if root else HF_LEROBOT_HOME / repo_id).resolve() + output_repo_id = new_repo_id or default_new_repo_id or repo_id + output_path = (Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id).resolve() + return output_repo_id, input_path, output_path + + def get_output_path( repo_id: str, new_repo_id: str | None, root: Path | str | None, new_root: Path | str | None, ) -> tuple[str, Path]: - input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id - - output_repo_id = new_repo_id if new_repo_id else repo_id - output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id + output_repo_id, input_path, output_path = _resolve_io_paths(repo_id, new_repo_id, root, new_root) # In case of in-place modification, create a backup of the original dataset (if it exists) if output_path == input_path: @@ -557,7 +585,39 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None: if not isinstance(cfg.operation, RecomputeStatsConfig): raise ValueError("Operation config must be RecomputeStatsConfig") - dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + # Determine whether this is an in-place operation + output_repo_id, input_root, output_root = _resolve_io_paths( + cfg.repo_id, + cfg.new_repo_id, + cfg.root, + cfg.new_root, + default_new_repo_id=f"{cfg.repo_id}_recomputed_stats", + ) + in_place = output_root == input_root + + if in_place and not cfg.operation.overwrite: + raise ValueError( + f"recompute_stats would overwrite the dataset in-place at {input_root}. " + "Pass --operation.overwrite true to allow in-place modification, " + "or use --new_repo_id / --new_root to write to a different location. " + f"Default output repo_id when neither is set: '{cfg.repo_id}_recomputed_stats'." + ) + + if in_place: + logging.warning( + f"Overwriting dataset stats in-place at {input_root}. The original stats will be lost." + ) + dataset = LeRobotDataset(cfg.repo_id, root=input_root) + else: + logging.info(f"Copying dataset from {input_root} to {output_root}") + if output_root.exists(): + backup_path = output_root.with_name(output_root.name + "_old") + logging.warning(f"Output directory {output_root} already exists. Moving to {backup_path}") + if backup_path.exists(): + shutil.rmtree(backup_path) + shutil.move(output_root, backup_path) + shutil.copytree(input_root, output_root) + dataset = LeRobotDataset(output_repo_id, root=output_root) logging.info(f"Recomputing stats for {cfg.repo_id}") if cfg.operation.relative_action: @@ -578,7 +638,7 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None: logging.info(f"Stats written to {dataset.root}") if cfg.push_to_hub: - logging.info(f"Pushing to hub as {dataset.meta.repo_id}...") + logging.info(f"Pushing to hub as {dataset.repo_id}...") dataset.push_to_hub()