fix(dataset): no default overwrite in lerobot tool recompute stats (#3452)

This commit is contained in:
Steven Palma
2026-04-24 15:07:19 +02:00
committed by GitHub
parent 587aa82021
commit 580d818aa9
+67 -7
View File
@@ -150,11 +150,24 @@ Show dataset information without feature details:
--operation.type info \ --operation.type info \
--operation.show_features false --operation.show_features false
Recompute dataset statistics: Recompute dataset statistics (saves to lerobot/pusht_recomputed_stats by default):
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht \ --repo_id lerobot/pusht \
--operation.type recompute_stats --operation.type recompute_stats
Recompute stats and save to a specific new repo_id:
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_new_stats \
--operation.type recompute_stats
Recompute stats in-place (overwrites original dataset stats):
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht \
--operation.type recompute_stats \
--operation.overwrite true
Recompute stats for relative actions and push to hub: Recompute stats for relative actions and push to hub:
lerobot-edit-dataset \ lerobot-edit-dataset \
--repo_id lerobot/pusht \ --repo_id lerobot/pusht \
@@ -256,6 +269,7 @@ class RecomputeStatsConfig(OperationConfig):
relative_exclude_joints: list[str] | None = None relative_exclude_joints: list[str] | None = None
chunk_size: int = 50 chunk_size: int = 50
num_workers: int = 0 num_workers: int = 0
overwrite: bool = False
@OperationConfig.register_subclass("info") @OperationConfig.register_subclass("info")
@@ -280,16 +294,30 @@ class EditDatasetConfig:
push_to_hub: bool = False push_to_hub: bool = False
def _resolve_io_paths(
repo_id: str,
new_repo_id: str | None,
root: Path | str | None,
new_root: Path | str | None,
default_new_repo_id: str | None = None,
) -> tuple[str, Path, Path]:
"""Resolve input/output paths and repo_id for dataset operations.
Returns (output_repo_id, input_path, output_path) with resolved (symlink-safe) paths.
"""
input_path = (Path(root) if root else HF_LEROBOT_HOME / repo_id).resolve()
output_repo_id = new_repo_id or default_new_repo_id or repo_id
output_path = (Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id).resolve()
return output_repo_id, input_path, output_path
def get_output_path( def get_output_path(
repo_id: str, repo_id: str,
new_repo_id: str | None, new_repo_id: str | None,
root: Path | str | None, root: Path | str | None,
new_root: Path | str | None, new_root: Path | str | None,
) -> tuple[str, Path]: ) -> tuple[str, Path]:
input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id output_repo_id, input_path, output_path = _resolve_io_paths(repo_id, new_repo_id, root, new_root)
output_repo_id = new_repo_id if new_repo_id else repo_id
output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id
# In case of in-place modification, create a backup of the original dataset (if it exists) # In case of in-place modification, create a backup of the original dataset (if it exists)
if output_path == input_path: if output_path == input_path:
@@ -557,7 +585,39 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, RecomputeStatsConfig): if not isinstance(cfg.operation, RecomputeStatsConfig):
raise ValueError("Operation config must be RecomputeStatsConfig") raise ValueError("Operation config must be RecomputeStatsConfig")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) # Determine whether this is an in-place operation
output_repo_id, input_root, output_root = _resolve_io_paths(
cfg.repo_id,
cfg.new_repo_id,
cfg.root,
cfg.new_root,
default_new_repo_id=f"{cfg.repo_id}_recomputed_stats",
)
in_place = output_root == input_root
if in_place and not cfg.operation.overwrite:
raise ValueError(
f"recompute_stats would overwrite the dataset in-place at {input_root}. "
"Pass --operation.overwrite true to allow in-place modification, "
"or use --new_repo_id / --new_root to write to a different location. "
f"Default output repo_id when neither is set: '{cfg.repo_id}_recomputed_stats'."
)
if in_place:
logging.warning(
f"Overwriting dataset stats in-place at {input_root}. The original stats will be lost."
)
dataset = LeRobotDataset(cfg.repo_id, root=input_root)
else:
logging.info(f"Copying dataset from {input_root} to {output_root}")
if output_root.exists():
backup_path = output_root.with_name(output_root.name + "_old")
logging.warning(f"Output directory {output_root} already exists. Moving to {backup_path}")
if backup_path.exists():
shutil.rmtree(backup_path)
shutil.move(output_root, backup_path)
shutil.copytree(input_root, output_root)
dataset = LeRobotDataset(output_repo_id, root=output_root)
logging.info(f"Recomputing stats for {cfg.repo_id}") logging.info(f"Recomputing stats for {cfg.repo_id}")
if cfg.operation.relative_action: if cfg.operation.relative_action:
@@ -578,7 +638,7 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
logging.info(f"Stats written to {dataset.root}") logging.info(f"Stats written to {dataset.root}")
if cfg.push_to_hub: if cfg.push_to_hub:
logging.info(f"Pushing to hub as {dataset.meta.repo_id}...") logging.info(f"Pushing to hub as {dataset.repo_id}...")
dataset.push_to_hub() dataset.push_to_hub()