mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(dataset): no default overwrite in lerobot tool recompute stats (#3452)
This commit is contained in:
@@ -150,11 +150,24 @@ Show dataset information without feature details:
|
|||||||
--operation.type info \
|
--operation.type info \
|
||||||
--operation.show_features false
|
--operation.show_features false
|
||||||
|
|
||||||
Recompute dataset statistics:
|
Recompute dataset statistics (saves to lerobot/pusht_recomputed_stats by default):
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
--repo_id lerobot/pusht \
|
--repo_id lerobot/pusht \
|
||||||
--operation.type recompute_stats
|
--operation.type recompute_stats
|
||||||
|
|
||||||
|
Recompute stats and save to a specific new repo_id:
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--new_repo_id lerobot/pusht_new_stats \
|
||||||
|
--operation.type recompute_stats
|
||||||
|
|
||||||
|
Recompute stats in-place (overwrites original dataset stats):
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--new_repo_id lerobot/pusht \
|
||||||
|
--operation.type recompute_stats \
|
||||||
|
--operation.overwrite true
|
||||||
|
|
||||||
Recompute stats for relative actions and push to hub:
|
Recompute stats for relative actions and push to hub:
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
--repo_id lerobot/pusht \
|
--repo_id lerobot/pusht \
|
||||||
@@ -256,6 +269,7 @@ class RecomputeStatsConfig(OperationConfig):
|
|||||||
relative_exclude_joints: list[str] | None = None
|
relative_exclude_joints: list[str] | None = None
|
||||||
chunk_size: int = 50
|
chunk_size: int = 50
|
||||||
num_workers: int = 0
|
num_workers: int = 0
|
||||||
|
overwrite: bool = False
|
||||||
|
|
||||||
|
|
||||||
@OperationConfig.register_subclass("info")
|
@OperationConfig.register_subclass("info")
|
||||||
@@ -280,16 +294,30 @@ class EditDatasetConfig:
|
|||||||
push_to_hub: bool = False
|
push_to_hub: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_io_paths(
|
||||||
|
repo_id: str,
|
||||||
|
new_repo_id: str | None,
|
||||||
|
root: Path | str | None,
|
||||||
|
new_root: Path | str | None,
|
||||||
|
default_new_repo_id: str | None = None,
|
||||||
|
) -> tuple[str, Path, Path]:
|
||||||
|
"""Resolve input/output paths and repo_id for dataset operations.
|
||||||
|
|
||||||
|
Returns (output_repo_id, input_path, output_path) with resolved (symlink-safe) paths.
|
||||||
|
"""
|
||||||
|
input_path = (Path(root) if root else HF_LEROBOT_HOME / repo_id).resolve()
|
||||||
|
output_repo_id = new_repo_id or default_new_repo_id or repo_id
|
||||||
|
output_path = (Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id).resolve()
|
||||||
|
return output_repo_id, input_path, output_path
|
||||||
|
|
||||||
|
|
||||||
def get_output_path(
|
def get_output_path(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
new_repo_id: str | None,
|
new_repo_id: str | None,
|
||||||
root: Path | str | None,
|
root: Path | str | None,
|
||||||
new_root: Path | str | None,
|
new_root: Path | str | None,
|
||||||
) -> tuple[str, Path]:
|
) -> tuple[str, Path]:
|
||||||
input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
output_repo_id, input_path, output_path = _resolve_io_paths(repo_id, new_repo_id, root, new_root)
|
||||||
|
|
||||||
output_repo_id = new_repo_id if new_repo_id else repo_id
|
|
||||||
output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id
|
|
||||||
|
|
||||||
# In case of in-place modification, create a backup of the original dataset (if it exists)
|
# In case of in-place modification, create a backup of the original dataset (if it exists)
|
||||||
if output_path == input_path:
|
if output_path == input_path:
|
||||||
@@ -557,7 +585,39 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
|||||||
if not isinstance(cfg.operation, RecomputeStatsConfig):
|
if not isinstance(cfg.operation, RecomputeStatsConfig):
|
||||||
raise ValueError("Operation config must be RecomputeStatsConfig")
|
raise ValueError("Operation config must be RecomputeStatsConfig")
|
||||||
|
|
||||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
# Determine whether this is an in-place operation
|
||||||
|
output_repo_id, input_root, output_root = _resolve_io_paths(
|
||||||
|
cfg.repo_id,
|
||||||
|
cfg.new_repo_id,
|
||||||
|
cfg.root,
|
||||||
|
cfg.new_root,
|
||||||
|
default_new_repo_id=f"{cfg.repo_id}_recomputed_stats",
|
||||||
|
)
|
||||||
|
in_place = output_root == input_root
|
||||||
|
|
||||||
|
if in_place and not cfg.operation.overwrite:
|
||||||
|
raise ValueError(
|
||||||
|
f"recompute_stats would overwrite the dataset in-place at {input_root}. "
|
||||||
|
"Pass --operation.overwrite true to allow in-place modification, "
|
||||||
|
"or use --new_repo_id / --new_root to write to a different location. "
|
||||||
|
f"Default output repo_id when neither is set: '{cfg.repo_id}_recomputed_stats'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if in_place:
|
||||||
|
logging.warning(
|
||||||
|
f"Overwriting dataset stats in-place at {input_root}. The original stats will be lost."
|
||||||
|
)
|
||||||
|
dataset = LeRobotDataset(cfg.repo_id, root=input_root)
|
||||||
|
else:
|
||||||
|
logging.info(f"Copying dataset from {input_root} to {output_root}")
|
||||||
|
if output_root.exists():
|
||||||
|
backup_path = output_root.with_name(output_root.name + "_old")
|
||||||
|
logging.warning(f"Output directory {output_root} already exists. Moving to {backup_path}")
|
||||||
|
if backup_path.exists():
|
||||||
|
shutil.rmtree(backup_path)
|
||||||
|
shutil.move(output_root, backup_path)
|
||||||
|
shutil.copytree(input_root, output_root)
|
||||||
|
dataset = LeRobotDataset(output_repo_id, root=output_root)
|
||||||
|
|
||||||
logging.info(f"Recomputing stats for {cfg.repo_id}")
|
logging.info(f"Recomputing stats for {cfg.repo_id}")
|
||||||
if cfg.operation.relative_action:
|
if cfg.operation.relative_action:
|
||||||
@@ -578,7 +638,7 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
|||||||
logging.info(f"Stats written to {dataset.root}")
|
logging.info(f"Stats written to {dataset.root}")
|
||||||
|
|
||||||
if cfg.push_to_hub:
|
if cfg.push_to_hub:
|
||||||
logging.info(f"Pushing to hub as {dataset.meta.repo_id}...")
|
logging.info(f"Pushing to hub as {dataset.repo_id}...")
|
||||||
dataset.push_to_hub()
|
dataset.push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user