mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(dataset): no default overwrite in lerobot tool recompute stats (#3452)
This commit is contained in:
@@ -150,11 +150,24 @@ Show dataset information without feature details:
|
||||
--operation.type info \
|
||||
--operation.show_features false
|
||||
|
||||
Recompute dataset statistics:
|
||||
Recompute dataset statistics (saves to lerobot/pusht_recomputed_stats by default):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type recompute_stats
|
||||
|
||||
Recompute stats and save to a specific new repo_id:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_new_stats \
|
||||
--operation.type recompute_stats
|
||||
|
||||
Recompute stats in-place (overwrites original dataset stats):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht \
|
||||
--operation.type recompute_stats \
|
||||
--operation.overwrite true
|
||||
|
||||
Recompute stats for relative actions and push to hub:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
@@ -256,6 +269,7 @@ class RecomputeStatsConfig(OperationConfig):
|
||||
relative_exclude_joints: list[str] | None = None
|
||||
chunk_size: int = 50
|
||||
num_workers: int = 0
|
||||
overwrite: bool = False
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("info")
|
||||
@@ -280,16 +294,30 @@ class EditDatasetConfig:
|
||||
push_to_hub: bool = False
|
||||
|
||||
|
||||
def _resolve_io_paths(
|
||||
repo_id: str,
|
||||
new_repo_id: str | None,
|
||||
root: Path | str | None,
|
||||
new_root: Path | str | None,
|
||||
default_new_repo_id: str | None = None,
|
||||
) -> tuple[str, Path, Path]:
|
||||
"""Resolve input/output paths and repo_id for dataset operations.
|
||||
|
||||
Returns (output_repo_id, input_path, output_path) with resolved (symlink-safe) paths.
|
||||
"""
|
||||
input_path = (Path(root) if root else HF_LEROBOT_HOME / repo_id).resolve()
|
||||
output_repo_id = new_repo_id or default_new_repo_id or repo_id
|
||||
output_path = (Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id).resolve()
|
||||
return output_repo_id, input_path, output_path
|
||||
|
||||
|
||||
def get_output_path(
|
||||
repo_id: str,
|
||||
new_repo_id: str | None,
|
||||
root: Path | str | None,
|
||||
new_root: Path | str | None,
|
||||
) -> tuple[str, Path]:
|
||||
input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
output_repo_id = new_repo_id if new_repo_id else repo_id
|
||||
output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id
|
||||
output_repo_id, input_path, output_path = _resolve_io_paths(repo_id, new_repo_id, root, new_root)
|
||||
|
||||
# In case of in-place modification, create a backup of the original dataset (if it exists)
|
||||
if output_path == input_path:
|
||||
@@ -557,7 +585,39 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, RecomputeStatsConfig):
|
||||
raise ValueError("Operation config must be RecomputeStatsConfig")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
# Determine whether this is an in-place operation
|
||||
output_repo_id, input_root, output_root = _resolve_io_paths(
|
||||
cfg.repo_id,
|
||||
cfg.new_repo_id,
|
||||
cfg.root,
|
||||
cfg.new_root,
|
||||
default_new_repo_id=f"{cfg.repo_id}_recomputed_stats",
|
||||
)
|
||||
in_place = output_root == input_root
|
||||
|
||||
if in_place and not cfg.operation.overwrite:
|
||||
raise ValueError(
|
||||
f"recompute_stats would overwrite the dataset in-place at {input_root}. "
|
||||
"Pass --operation.overwrite true to allow in-place modification, "
|
||||
"or use --new_repo_id / --new_root to write to a different location. "
|
||||
f"Default output repo_id when neither is set: '{cfg.repo_id}_recomputed_stats'."
|
||||
)
|
||||
|
||||
if in_place:
|
||||
logging.warning(
|
||||
f"Overwriting dataset stats in-place at {input_root}. The original stats will be lost."
|
||||
)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=input_root)
|
||||
else:
|
||||
logging.info(f"Copying dataset from {input_root} to {output_root}")
|
||||
if output_root.exists():
|
||||
backup_path = output_root.with_name(output_root.name + "_old")
|
||||
logging.warning(f"Output directory {output_root} already exists. Moving to {backup_path}")
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(output_root, backup_path)
|
||||
shutil.copytree(input_root, output_root)
|
||||
dataset = LeRobotDataset(output_repo_id, root=output_root)
|
||||
|
||||
logging.info(f"Recomputing stats for {cfg.repo_id}")
|
||||
if cfg.operation.relative_action:
|
||||
@@ -578,7 +638,7 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
||||
logging.info(f"Stats written to {dataset.root}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {dataset.meta.repo_id}...")
|
||||
logging.info(f"Pushing to hub as {dataset.repo_id}...")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user