mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
rename to fold
This commit is contained in:
+19
-109
@@ -15,19 +15,17 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Unify/remap tasks in a dataset based on shirt ID.
|
Unify all tasks in a dataset to a single task.
|
||||||
|
|
||||||
This script:
|
This script:
|
||||||
1. Loads a dataset with shirt_id feature
|
1. Loads a dataset
|
||||||
2. Assigns tasks based on shirt ID:
|
2. Sets all task_index to 0 and task description to "fold"
|
||||||
- Shirt IDs 0XX (starting with 0): "Fold the T-shirt properly"
|
|
||||||
- Shirt IDs 1XX, 2XX, etc.: "Layout the t-shirt on the table in an organized manner, then fold the t-shirt properly"
|
|
||||||
3. Updates tasks.parquet and task_index in data files
|
3. Updates tasks.parquet and task_index in data files
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python unify_tasks.py \
|
python examples/openarms/unify_task.py \
|
||||||
--input-repo-id lerobot-data-collection/full_folding_2025-11-30 \
|
--input-repo-id lerobot-data-collection/level1_rac1 \
|
||||||
--output-repo-id lerobot-data-collection/single_task_folding_2025-11-30
|
--output-repo-id lerobot-data-collection/level1_rac1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -50,25 +48,8 @@ from lerobot.datasets.utils import (
|
|||||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||||
|
|
||||||
|
|
||||||
# Task definitions based on shirt ID
|
# Single unified task
|
||||||
TASK_FOLD_ONLY = "Fold the T-shirt properly"
|
UNIFIED_TASK = "fold"
|
||||||
TASK_LAYOUT_AND_FOLD = "Layout the t-shirt on the table in an organized manner, then fold the t-shirt properly"
|
|
||||||
|
|
||||||
|
|
||||||
def get_task_for_shirt_id(shirt_id: int) -> tuple[str, int]:
|
|
||||||
"""Get the task string and index based on shirt ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
shirt_id: The shirt ID (e.g., 2, 112, 219)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (task_string, task_index)
|
|
||||||
- Shirt IDs 0-99 (0XX): task_index=0, fold only
|
|
||||||
- Shirt IDs 100+ (1XX, 2XX, ...): task_index=1, layout and fold
|
|
||||||
"""
|
|
||||||
if shirt_id < 100:
|
|
||||||
return TASK_FOLD_ONLY, 0
|
|
||||||
return TASK_LAYOUT_AND_FOLD, 1
|
|
||||||
|
|
||||||
|
|
||||||
def unify_dataset_tasks(
|
def unify_dataset_tasks(
|
||||||
@@ -78,7 +59,7 @@ def unify_dataset_tasks(
|
|||||||
output_root: Path | None = None,
|
output_root: Path | None = None,
|
||||||
push_to_hub: bool = False,
|
push_to_hub: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Remap tasks in a dataset based on shirt ID.
|
"""Unify all tasks in a dataset to a single task.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_repo_id: Source dataset repository ID.
|
input_repo_id: Source dataset repository ID.
|
||||||
@@ -98,13 +79,6 @@ def unify_dataset_tasks(
|
|||||||
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
||||||
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
||||||
|
|
||||||
# Check if shirt_id feature exists
|
|
||||||
if "shirt_id" not in src_meta.features:
|
|
||||||
raise ValueError(
|
|
||||||
"Dataset does not have 'shirt_id' feature. "
|
|
||||||
"Please add it first using the add_features function."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create output directory
|
# Create output directory
|
||||||
if output_root.exists():
|
if output_root.exists():
|
||||||
logging.warning(f"Output directory {output_root} exists, removing it")
|
logging.warning(f"Output directory {output_root} exists, removing it")
|
||||||
@@ -118,106 +92,44 @@ def unify_dataset_tasks(
|
|||||||
logging.info("Copying videos...")
|
logging.info("Copying videos...")
|
||||||
shutil.copytree(src_videos, output_root / "videos")
|
shutil.copytree(src_videos, output_root / "videos")
|
||||||
|
|
||||||
# Process data files - update task_index based on shirt_id
|
# Process data files - set all task_index to 0
|
||||||
logging.info("Processing data files...")
|
logging.info("Processing data files...")
|
||||||
src_data_dir = input_root / DATA_DIR
|
src_data_dir = input_root / DATA_DIR
|
||||||
dst_data_dir = output_root / DATA_DIR
|
dst_data_dir = output_root / DATA_DIR
|
||||||
dst_data_dir.mkdir(parents=True, exist_ok=True)
|
dst_data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Track which tasks are used
|
|
||||||
tasks_used = set()
|
|
||||||
|
|
||||||
for src_parquet in tqdm(sorted(src_data_dir.rglob("*.parquet")), desc="Processing data"):
|
for src_parquet in tqdm(sorted(src_data_dir.rglob("*.parquet")), desc="Processing data"):
|
||||||
rel_path = src_parquet.relative_to(input_root)
|
rel_path = src_parquet.relative_to(input_root)
|
||||||
dst_parquet = output_root / rel_path
|
dst_parquet = output_root / rel_path
|
||||||
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
|
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
df = pd.read_parquet(src_parquet)
|
df = pd.read_parquet(src_parquet)
|
||||||
|
df["task_index"] = 0 # All tasks unified to index 0
|
||||||
# Get shirt_id and compute task_index for each row
|
|
||||||
if "shirt_id" in df.columns:
|
|
||||||
# shirt_id might be shape (1,) array or scalar
|
|
||||||
def extract_shirt_id(val):
|
|
||||||
if hasattr(val, "__len__") and len(val) == 1:
|
|
||||||
return int(val[0])
|
|
||||||
return int(val)
|
|
||||||
|
|
||||||
df["task_index"] = df["shirt_id"].apply(
|
|
||||||
lambda x: get_task_for_shirt_id(extract_shirt_id(x))[1]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Track which tasks are used
|
|
||||||
unique_shirt_ids = df["shirt_id"].apply(extract_shirt_id).unique()
|
|
||||||
for sid in unique_shirt_ids:
|
|
||||||
task_str, _ = get_task_for_shirt_id(sid)
|
|
||||||
tasks_used.add(task_str)
|
|
||||||
else:
|
|
||||||
logging.warning(f"No shirt_id column in {src_parquet}, setting task_index=0")
|
|
||||||
df["task_index"] = 0
|
|
||||||
tasks_used.add(TASK_FOLD_ONLY)
|
|
||||||
|
|
||||||
df.to_parquet(dst_parquet)
|
df.to_parquet(dst_parquet)
|
||||||
|
|
||||||
# Process episodes metadata - update task references
|
# Process episodes metadata - set all tasks to unified task
|
||||||
logging.info("Processing episodes metadata...")
|
logging.info("Processing episodes metadata...")
|
||||||
src_episodes_dir = input_root / "meta" / "episodes"
|
src_episodes_dir = input_root / "meta" / "episodes"
|
||||||
dst_episodes_dir = output_root / "meta" / "episodes"
|
dst_episodes_dir = output_root / "meta" / "episodes"
|
||||||
dst_episodes_dir.mkdir(parents=True, exist_ok=True)
|
dst_episodes_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Build episode to shirt_id mapping by reading first frame of each episode
|
|
||||||
episode_shirt_ids = {}
|
|
||||||
for src_parquet in sorted(src_data_dir.rglob("*.parquet")):
|
|
||||||
df = pd.read_parquet(src_parquet)
|
|
||||||
if "shirt_id" in df.columns and "episode_index" in df.columns:
|
|
||||||
for ep_idx in df["episode_index"].unique():
|
|
||||||
if ep_idx not in episode_shirt_ids:
|
|
||||||
ep_data = df[df["episode_index"] == ep_idx].iloc[0]
|
|
||||||
shirt_val = ep_data["shirt_id"]
|
|
||||||
if hasattr(shirt_val, "__len__") and len(shirt_val) == 1:
|
|
||||||
episode_shirt_ids[int(ep_idx)] = int(shirt_val[0])
|
|
||||||
else:
|
|
||||||
episode_shirt_ids[int(ep_idx)] = int(shirt_val)
|
|
||||||
|
|
||||||
for src_parquet in tqdm(sorted(src_episodes_dir.rglob("*.parquet")), desc="Processing episodes"):
|
for src_parquet in tqdm(sorted(src_episodes_dir.rglob("*.parquet")), desc="Processing episodes"):
|
||||||
rel_path = src_parquet.relative_to(src_episodes_dir)
|
rel_path = src_parquet.relative_to(src_episodes_dir)
|
||||||
dst_parquet = dst_episodes_dir / rel_path
|
dst_parquet = dst_episodes_dir / rel_path
|
||||||
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
|
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
df = pd.read_parquet(src_parquet)
|
df = pd.read_parquet(src_parquet)
|
||||||
|
df["tasks"] = [[UNIFIED_TASK]] * len(df) # All episodes get the unified task
|
||||||
# Update tasks column based on episode's shirt_id
|
|
||||||
new_tasks_col = []
|
|
||||||
for idx, row in df.iterrows():
|
|
||||||
ep_idx = int(row["episode_index"])
|
|
||||||
shirt_id = episode_shirt_ids.get(ep_idx, 0)
|
|
||||||
task_str, _ = get_task_for_shirt_id(shirt_id)
|
|
||||||
new_tasks_col.append([task_str])
|
|
||||||
|
|
||||||
df["tasks"] = new_tasks_col
|
|
||||||
df.to_parquet(dst_parquet)
|
df.to_parquet(dst_parquet)
|
||||||
|
|
||||||
# Create new tasks.parquet with the tasks that are actually used
|
# Create new tasks.parquet with single task
|
||||||
logging.info(f"Creating tasks: {tasks_used}")
|
logging.info(f"Creating single task: {UNIFIED_TASK}")
|
||||||
task_list = sorted(tasks_used) # Sort for consistent ordering
|
new_tasks = pd.DataFrame({"task_index": [0]}, index=[UNIFIED_TASK])
|
||||||
# Ensure TASK_FOLD_ONLY is index 0 and TASK_LAYOUT_AND_FOLD is index 1
|
|
||||||
if TASK_FOLD_ONLY in task_list and TASK_LAYOUT_AND_FOLD in task_list:
|
|
||||||
task_list = [TASK_FOLD_ONLY, TASK_LAYOUT_AND_FOLD]
|
|
||||||
elif TASK_FOLD_ONLY in task_list:
|
|
||||||
task_list = [TASK_FOLD_ONLY]
|
|
||||||
elif TASK_LAYOUT_AND_FOLD in task_list:
|
|
||||||
# If only layout task is used, it should still be index 1 for consistency
|
|
||||||
# But we need index 0 to exist, so include both
|
|
||||||
task_list = [TASK_FOLD_ONLY, TASK_LAYOUT_AND_FOLD]
|
|
||||||
|
|
||||||
new_tasks = pd.DataFrame(
|
|
||||||
{"task_index": list(range(len(task_list)))},
|
|
||||||
index=task_list
|
|
||||||
)
|
|
||||||
write_tasks(new_tasks, output_root)
|
write_tasks(new_tasks, output_root)
|
||||||
|
|
||||||
# Update info.json
|
# Update info.json
|
||||||
new_info = src_meta.info.copy()
|
new_info = src_meta.info.copy()
|
||||||
new_info["total_tasks"] = len(task_list)
|
new_info["total_tasks"] = 1
|
||||||
write_info(new_info, output_root)
|
write_info(new_info, output_root)
|
||||||
|
|
||||||
# Copy stats.json (unchanged)
|
# Copy stats.json (unchanged)
|
||||||
@@ -225,7 +137,7 @@ def unify_dataset_tasks(
|
|||||||
write_stats(src_meta.stats, output_root)
|
write_stats(src_meta.stats, output_root)
|
||||||
|
|
||||||
logging.info(f"Dataset saved to {output_root}")
|
logging.info(f"Dataset saved to {output_root}")
|
||||||
logging.info(f"Tasks: {task_list}")
|
logging.info(f"Task: {UNIFIED_TASK}")
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
@@ -238,9 +150,7 @@ def unify_dataset_tasks(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Remap tasks in a dataset based on shirt ID. "
|
description="Unify all tasks in a dataset to a single task 'fold'."
|
||||||
"Shirt IDs 0-99 get 'Fold the T-shirt properly', "
|
|
||||||
"Shirt IDs 100+ get 'Layout and fold' task."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user