mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 6906178b39 | | |
| | cbc8bfb2e6 | | |
| | 0d1be72dc8 | | |
| | 96b7c212c4 | | |
| | 4303b3c930 | | |
| | 63dca86df8 | | |
| | 8a0cc3d664 | | |
| | 8bb8ed4803 | | |
@@ -48,7 +48,7 @@ python -m lerobot.async_inference.robot_client \
|
|||||||
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
|
||||||
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
|
||||||
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
|
||||||
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
|
--policy_device=mps \ # POLICY: the device to run the policy on, on the server (cuda, mps, xpu, cpu)
|
||||||
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
|
||||||
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
|
||||||
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
|
||||||
|
|||||||
@@ -170,13 +170,13 @@ Once you can drive the robot well, you can start recording data to train AI mode
|
|||||||
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
|
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||||
```
|
```
|
||||||
|
|
||||||
Store your Hugging Face username:
|
Store your Hugging Face username:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||||
echo $HF_USER
|
echo $HF_USER
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -155,10 +155,10 @@ Upload your repository to Hugging Face:
|
|||||||
pip install huggingface_hub
|
pip install huggingface_hub
|
||||||
|
|
||||||
# Login to Hugging Face
|
# Login to Hugging Face
|
||||||
huggingface-cli login
|
hf auth login
|
||||||
|
|
||||||
# Create a new repository
|
# Create a new repository
|
||||||
huggingface-cli repo create my-custom-env --type space --org my-org
|
hf repo create my-org/my-custom-env
|
||||||
|
|
||||||
# Initialize git and push
|
# Initialize git and push
|
||||||
git init
|
git init
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
|
|||||||
Add your token to the CLI by running this command:
|
Add your token to the CLI by running this command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||||
```
|
```
|
||||||
|
|
||||||
Then store your Hugging Face repository name in a variable:
|
Then store your Hugging Face repository name in a variable:
|
||||||
@@ -327,7 +327,7 @@ You can look for other LeRobot datasets on the hub by searching for `LeRobot` [t
|
|||||||
You can also push your local dataset to the Hub manually, running:
|
You can also push your local dataset to the Hub manually, running:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
|
hf upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Record function
|
#### Record function
|
||||||
@@ -491,7 +491,7 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
|
|||||||
Once training is done, upload the latest checkpoint with:
|
Once training is done, upload the latest checkpoint with:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
huggingface-cli upload ${HF_USER}/act_so101_test \
|
hf upload ${HF_USER}/act_so101_test \
|
||||||
outputs/train/act_so101_test/checkpoints/last/pretrained_model
|
outputs/train/act_so101_test/checkpoints/last/pretrained_model
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -499,7 +499,7 @@ You can also upload intermediate checkpoints with:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
CKPT=010000
|
CKPT=010000
|
||||||
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
hf upload ${HF_USER}/act_so101_test${CKPT} \
|
||||||
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
|
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -279,13 +279,13 @@ We use the Hugging Face hub features for uploading your dataset. If you haven't
|
|||||||
Add your token to the CLI by running this command:
|
Add your token to the CLI by running this command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
hf auth login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||||
```
|
```
|
||||||
|
|
||||||
Then store your Hugging Face repository name in a variable:
|
Then store your Hugging Face repository name in a variable:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||||
echo $HF_USER
|
echo $HF_USER
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.parquet as pq
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from lerobot.datasets.compute_stats import aggregate_stats
|
from lerobot.datasets.compute_stats import aggregate_stats
|
||||||
@@ -35,7 +37,6 @@ from lerobot.datasets.utils import (
|
|||||||
get_file_size_in_mb,
|
get_file_size_in_mb,
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
to_parquet_with_hf_images,
|
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
write_info,
|
write_info,
|
||||||
write_stats,
|
write_stats,
|
||||||
@@ -80,28 +81,41 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
|||||||
return fps, robot_type, features
|
return fps, robot_type, features
|
||||||
|
|
||||||
|
|
||||||
def update_data_df(df, src_meta, dst_meta):
|
def update_data_table(table: pa.Table, src_meta, dst_meta) -> pa.Table:
|
||||||
"""Updates a data DataFrame with new indices and task mappings for aggregation.
|
"""Updates a pyarrow Table with new indices and task mappings for aggregation.
|
||||||
|
|
||||||
Adjusts episode indices, frame indices, and task indices to account for
|
Adjusts episode indices, frame indices, and task indices to account for
|
||||||
previously aggregated data in the destination dataset.
|
previously aggregated data in the destination dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
df: DataFrame containing the data to be updated.
|
table: pyarrow Table containing the data to be updated.
|
||||||
src_meta: Source dataset metadata.
|
src_meta: Source dataset metadata.
|
||||||
dst_meta: Destination dataset metadata.
|
dst_meta: Destination dataset metadata.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
pa.Table: Updated Table with adjusted indices.
|
||||||
"""
|
"""
|
||||||
|
ep_offset = dst_meta.info["total_episodes"]
|
||||||
|
idx_offset = dst_meta.info["total_frames"]
|
||||||
|
|
||||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
ep_col = table.column("episode_index")
|
||||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
new_ep = pa.array([v + ep_offset for v in ep_col.to_pylist()], type=ep_col.type)
|
||||||
|
table = table.set_column(table.column_names.index("episode_index"), "episode_index", new_ep)
|
||||||
|
|
||||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
idx_col = table.column("index")
|
||||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
new_idx = pa.array([v + idx_offset for v in idx_col.to_pylist()], type=idx_col.type)
|
||||||
|
table = table.set_column(table.column_names.index("index"), "index", new_idx)
|
||||||
|
|
||||||
return df
|
old_task_indices = table.column("task_index").to_pylist()
|
||||||
|
src_task_names = src_meta.tasks.index.take(old_task_indices)
|
||||||
|
new_task_indices = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy().tolist()
|
||||||
|
table = table.set_column(
|
||||||
|
table.column_names.index("task_index"),
|
||||||
|
"task_index",
|
||||||
|
pa.array(new_task_indices, type=table.schema.field("task_index").type),
|
||||||
|
)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
def update_meta_data(
|
def update_meta_data(
|
||||||
@@ -289,7 +303,9 @@ def aggregate_datasets(
|
|||||||
|
|
||||||
logging.info("Find all tasks")
|
logging.info("Find all tasks")
|
||||||
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
|
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
|
||||||
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
|
dst_meta.tasks = pd.DataFrame(
|
||||||
|
{"task_index": range(len(unique_tasks))}, index=pd.Index(unique_tasks, name="task")
|
||||||
|
)
|
||||||
|
|
||||||
meta_idx = {"chunk": 0, "file": 0}
|
meta_idx = {"chunk": 0, "file": 0}
|
||||||
data_idx = {"chunk": 0, "file": 0}
|
data_idx = {"chunk": 0, "file": 0}
|
||||||
@@ -466,18 +482,13 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||||
)
|
)
|
||||||
if contains_images:
|
table = pq.read_table(src_path)
|
||||||
# Use HuggingFace datasets to read source data to preserve image format
|
table = update_data_table(table, src_meta, dst_meta)
|
||||||
src_ds = datasets.Dataset.from_parquet(str(src_path))
|
|
||||||
df = src_ds.to_pandas()
|
|
||||||
else:
|
|
||||||
df = pd.read_parquet(src_path)
|
|
||||||
df = update_data_df(df, src_meta, dst_meta)
|
|
||||||
|
|
||||||
# Write data and get the actual destination file it was written to
|
# Write data and get the actual destination file it was written to
|
||||||
# This avoids duplicating the rotation logic here
|
# This avoids duplicating the rotation logic here
|
||||||
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
|
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
|
||||||
df,
|
table,
|
||||||
src_path,
|
src_path,
|
||||||
data_idx,
|
data_idx,
|
||||||
data_files_size_in_mb,
|
data_files_size_in_mb,
|
||||||
@@ -552,8 +563,16 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
return meta_idx
|
return meta_idx
|
||||||
|
|
||||||
|
|
||||||
|
def _write_table_with_hf_images(
|
||||||
|
table: pa.Table, path: Path, features: datasets.Features | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Write a pyarrow Table to parquet with proper HF image encoding."""
|
||||||
|
ds = datasets.Dataset.from_dict(table.to_pydict(), features=features)
|
||||||
|
ds.to_parquet(path)
|
||||||
|
|
||||||
|
|
||||||
def append_or_create_parquet_file(
|
def append_or_create_parquet_file(
|
||||||
df: pd.DataFrame,
|
data: pd.DataFrame | pa.Table,
|
||||||
src_path: Path,
|
src_path: Path,
|
||||||
idx: dict[str, int],
|
idx: dict[str, int],
|
||||||
max_mb: float,
|
max_mb: float,
|
||||||
@@ -569,7 +588,7 @@ def append_or_create_parquet_file(
|
|||||||
from becoming too large. Handles both regular parquet files and those containing images.
|
from becoming too large. Handles both regular parquet files and those containing images.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
df: DataFrame to write to the parquet file.
|
data: Data to write, as a pandas DataFrame or pyarrow Table.
|
||||||
src_path: Path to the source file (used for size estimation).
|
src_path: Path to the source file (used for size estimation).
|
||||||
idx: Dictionary containing current 'chunk' and 'file' indices.
|
idx: Dictionary containing current 'chunk' and 'file' indices.
|
||||||
max_mb: Maximum allowed file size in MB before rotation.
|
max_mb: Maximum allowed file size in MB before rotation.
|
||||||
@@ -583,15 +602,17 @@ def append_or_create_parquet_file(
|
|||||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||||
and (dst_chunk, dst_file) is the actual destination file the data was written to.
|
and (dst_chunk, dst_file) is the actual destination file the data was written to.
|
||||||
"""
|
"""
|
||||||
|
table = data if isinstance(data, pa.Table) else pa.Table.from_pandas(data)
|
||||||
|
|
||||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||||
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||||
|
|
||||||
if not dst_path.exists():
|
if not dst_path.exists():
|
||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
_write_table_with_hf_images(table, dst_path, features=hf_features)
|
||||||
else:
|
else:
|
||||||
df.to_parquet(dst_path)
|
pq.write_table(table, dst_path)
|
||||||
return idx, (dst_chunk, dst_file)
|
return idx, (dst_chunk, dst_file)
|
||||||
|
|
||||||
src_size = get_parquet_file_size_in_mb(src_path)
|
src_size = get_parquet_file_size_in_mb(src_path)
|
||||||
@@ -602,22 +623,17 @@ def append_or_create_parquet_file(
|
|||||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
final_df = df
|
final_table = table
|
||||||
target_path = new_path
|
target_path = new_path
|
||||||
else:
|
else:
|
||||||
if contains_images:
|
existing_table = pq.read_table(dst_path)
|
||||||
# Use HuggingFace datasets to read existing data to preserve image format
|
final_table = pa.concat_tables([existing_table, table], promote_options="permissive")
|
||||||
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
|
|
||||||
existing_df = existing_ds.to_pandas()
|
|
||||||
else:
|
|
||||||
existing_df = pd.read_parquet(dst_path)
|
|
||||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
|
||||||
target_path = dst_path
|
target_path = dst_path
|
||||||
|
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
_write_table_with_hf_images(final_table, target_path, features=hf_features)
|
||||||
else:
|
else:
|
||||||
final_df.to_parquet(target_path)
|
pq.write_table(final_table, target_path)
|
||||||
|
|
||||||
return idx, (dst_chunk, dst_file)
|
return idx, (dst_chunk, dst_file)
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ from pathlib import Path
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.compute as pc
|
||||||
|
import pyarrow.dataset as pa_ds
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -89,8 +92,8 @@ def delete_episodes(
|
|||||||
Args:
|
Args:
|
||||||
dataset: The source LeRobotDataset.
|
dataset: The source LeRobotDataset.
|
||||||
episode_indices: List of episode indices to delete.
|
episode_indices: List of episode indices to delete.
|
||||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||||
"""
|
"""
|
||||||
if not episode_indices:
|
if not episode_indices:
|
||||||
raise ValueError("No episodes to delete")
|
raise ValueError("No episodes to delete")
|
||||||
@@ -152,7 +155,7 @@ def split_dataset(
|
|||||||
dataset: The source LeRobotDataset to split.
|
dataset: The source LeRobotDataset to split.
|
||||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||||
split names to fractions (must sum to <= 1.0).
|
split names to fractions (must sum to <= 1.0).
|
||||||
output_dir: Base directory for output datasets. If None, uses default location.
|
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
Split by specific episodes
|
Split by specific episodes
|
||||||
@@ -243,8 +246,8 @@ def merge_datasets(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
datasets: List of LeRobotDatasets to merge.
|
datasets: List of LeRobotDatasets to merge.
|
||||||
output_repo_id: Repository ID for the merged dataset.
|
output_repo_id: Merged dataset identifier.
|
||||||
output_dir: Directory to save the merged dataset. If None, uses default location.
|
output_dir: Root directory where the merged dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/output_repo_id.
|
||||||
"""
|
"""
|
||||||
if not datasets:
|
if not datasets:
|
||||||
raise ValueError("No datasets to merge")
|
raise ValueError("No datasets to merge")
|
||||||
@@ -288,8 +291,8 @@ def modify_features(
|
|||||||
dataset: The source LeRobotDataset.
|
dataset: The source LeRobotDataset.
|
||||||
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
|
add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
|
||||||
remove_features: Optional feature name(s) to remove. Can be a single string or list.
|
remove_features: Optional feature name(s) to remove. Can be a single string or list.
|
||||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
New dataset with features modified.
|
New dataset with features modified.
|
||||||
@@ -390,8 +393,8 @@ def add_features(
|
|||||||
Args:
|
Args:
|
||||||
dataset: The source LeRobotDataset.
|
dataset: The source LeRobotDataset.
|
||||||
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
|
features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
|
||||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
New dataset with all features added.
|
New dataset with all features added.
|
||||||
@@ -427,8 +430,8 @@ def remove_feature(
|
|||||||
Args:
|
Args:
|
||||||
dataset: The source LeRobotDataset.
|
dataset: The source LeRobotDataset.
|
||||||
feature_names: Name(s) of features to remove. Can be a single string or list.
|
feature_names: Name(s) of features to remove. Can be a single string or list.
|
||||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||||
repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.
|
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
New dataset with features removed.
|
New dataset with features removed.
|
||||||
@@ -496,13 +499,16 @@ def _copy_and_reindex_data(
|
|||||||
global_index = 0
|
global_index = 0
|
||||||
episode_data_metadata: dict[int, dict] = {}
|
episode_data_metadata: dict[int, dict] = {}
|
||||||
|
|
||||||
|
episode_keys = list(episode_mapping.keys())
|
||||||
|
ep_filter = pa_ds.field("episode_index").isin(episode_keys)
|
||||||
|
|
||||||
if dst_meta.tasks is None:
|
if dst_meta.tasks is None:
|
||||||
all_task_indices = set()
|
all_task_indices: set[int] = set()
|
||||||
for src_path in file_to_episodes:
|
for src_path in file_to_episodes:
|
||||||
df = pd.read_parquet(src_dataset.root / src_path)
|
table = pq.read_table(
|
||||||
mask = df["episode_index"].isin(list(episode_mapping.keys()))
|
src_dataset.root / src_path, columns=["episode_index", "task_index"], filters=ep_filter
|
||||||
task_series: pd.Series = df[mask]["task_index"]
|
)
|
||||||
all_task_indices.update(task_series.unique().tolist())
|
all_task_indices.update(pc.unique(table.column("task_index")).to_pylist())
|
||||||
tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices]
|
tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices]
|
||||||
dst_meta.save_episode_tasks(list(set(tasks)))
|
dst_meta.save_episode_tasks(list(set(tasks)))
|
||||||
|
|
||||||
@@ -514,52 +520,41 @@ def _copy_and_reindex_data(
|
|||||||
task_mapping[old_task_idx] = new_task_idx
|
task_mapping[old_task_idx] = new_task_idx
|
||||||
|
|
||||||
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
|
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
|
||||||
df = pd.read_parquet(src_dataset.root / src_path)
|
table = pq.read_table(src_dataset.root / src_path, filters=ep_filter)
|
||||||
|
|
||||||
all_episodes_in_file = set(df["episode_index"].unique())
|
|
||||||
episodes_to_keep = file_to_episodes[src_path]
|
episodes_to_keep = file_to_episodes[src_path]
|
||||||
|
|
||||||
if all_episodes_in_file == episodes_to_keep:
|
if table.num_rows == 0:
|
||||||
df["episode_index"] = df["episode_index"].replace(episode_mapping)
|
continue
|
||||||
df["index"] = range(global_index, global_index + len(df))
|
|
||||||
df["task_index"] = df["task_index"].replace(task_mapping)
|
|
||||||
|
|
||||||
first_ep_old_idx = min(episodes_to_keep)
|
table = _replace_column_values(table, "episode_index", episode_mapping)
|
||||||
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
|
col_pos = table.column_names.index("index")
|
||||||
chunk_idx = src_ep["data/chunk_index"]
|
new_indices = pa.array(range(global_index, global_index + table.num_rows), type=pa.int64())
|
||||||
file_idx = src_ep["data/file_index"]
|
table = table.set_column(col_pos, "index", new_indices)
|
||||||
else:
|
table = _replace_column_values(table, "task_index", task_mapping)
|
||||||
mask = df["episode_index"].isin(list(episode_mapping.keys()))
|
|
||||||
df = df[mask].copy().reset_index(drop=True)
|
|
||||||
|
|
||||||
if len(df) == 0:
|
first_ep_old_idx = min(episodes_to_keep)
|
||||||
continue
|
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
|
||||||
|
chunk_idx = src_ep["data/chunk_index"]
|
||||||
df["episode_index"] = df["episode_index"].replace(episode_mapping)
|
file_idx = src_ep["data/file_index"]
|
||||||
df["index"] = range(global_index, global_index + len(df))
|
|
||||||
df["task_index"] = df["task_index"].replace(task_mapping)
|
|
||||||
|
|
||||||
first_ep_old_idx = min(episodes_to_keep)
|
|
||||||
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
|
|
||||||
chunk_idx = src_ep["data/chunk_index"]
|
|
||||||
file_idx = src_ep["data/file_index"]
|
|
||||||
|
|
||||||
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
_write_parquet(df, dst_path, dst_meta)
|
_write_parquet(table, dst_path, dst_meta)
|
||||||
|
|
||||||
|
ep_col = table.column("episode_index").to_pylist()
|
||||||
|
idx_col = table.column("index").to_pylist()
|
||||||
for ep_old_idx in episodes_to_keep:
|
for ep_old_idx in episodes_to_keep:
|
||||||
ep_new_idx = episode_mapping[ep_old_idx]
|
ep_new_idx = episode_mapping[ep_old_idx]
|
||||||
ep_df = df[df["episode_index"] == ep_new_idx]
|
ep_indices = [idx_col[i] for i, e in enumerate(ep_col) if e == ep_new_idx]
|
||||||
episode_data_metadata[ep_new_idx] = {
|
episode_data_metadata[ep_new_idx] = {
|
||||||
"data/chunk_index": chunk_idx,
|
"data/chunk_index": chunk_idx,
|
||||||
"data/file_index": file_idx,
|
"data/file_index": file_idx,
|
||||||
"dataset_from_index": int(ep_df["index"].min()),
|
"dataset_from_index": min(ep_indices),
|
||||||
"dataset_to_index": int(ep_df["index"].max() + 1),
|
"dataset_to_index": max(ep_indices) + 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
global_index += len(df)
|
global_index += table.num_rows
|
||||||
|
|
||||||
return episode_data_metadata
|
return episode_data_metadata
|
||||||
|
|
||||||
@@ -910,15 +905,39 @@ def _copy_and_reindex_episodes_metadata(
|
|||||||
write_stats(filtered_stats, dst_meta.root)
|
write_stats(filtered_stats, dst_meta.root)
|
||||||
|
|
||||||
|
|
||||||
def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -> None:
|
def _replace_column_values(table: pa.Table, column: str, mapping: dict) -> pa.Table:
|
||||||
"""Write DataFrame to parquet
|
"""Replace values in a pyarrow Table column using a mapping dict."""
|
||||||
|
old_values = table.column(column).to_pylist()
|
||||||
|
new_values = [mapping.get(v, v) for v in old_values]
|
||||||
|
col_pos = table.column_names.index(column)
|
||||||
|
return table.set_column(col_pos, column, pa.array(new_values, type=table.schema.field(column).type))
|
||||||
|
|
||||||
|
|
||||||
|
def _write_parquet(
|
||||||
|
data: pd.DataFrame | pa.Table | dict, path: Path, meta: LeRobotDatasetMetadata
|
||||||
|
) -> None:
|
||||||
|
"""Write data to parquet.
|
||||||
|
|
||||||
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
|
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Input data as a pandas DataFrame, pyarrow Table, or dict of lists.
|
||||||
|
path: Destination parquet file path.
|
||||||
|
meta: Dataset metadata for feature schema.
|
||||||
"""
|
"""
|
||||||
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
|
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
|
||||||
|
|
||||||
|
if isinstance(data, pd.DataFrame):
|
||||||
|
data_dict = data.to_dict(orient="list")
|
||||||
|
elif isinstance(data, pa.Table):
|
||||||
|
data_dict = data.to_pydict()
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
data_dict = data
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported data type: {type(data)}")
|
||||||
|
|
||||||
hf_features = get_hf_features_from_features(meta.features)
|
hf_features = get_hf_features_from_features(meta.features)
|
||||||
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
|
ep_dataset = datasets.Dataset.from_dict(data_dict, features=hf_features, split="train")
|
||||||
|
|
||||||
if len(meta.image_keys) > 0:
|
if len(meta.image_keys) > 0:
|
||||||
ep_dataset = embed_images(ep_dataset)
|
ep_dataset = embed_images(ep_dataset)
|
||||||
@@ -1475,7 +1494,9 @@ def modify_tasks(
|
|||||||
|
|
||||||
# Collect all unique tasks and create new task mapping
|
# Collect all unique tasks and create new task mapping
|
||||||
unique_tasks = sorted(set(episode_to_task.values()))
|
unique_tasks = sorted(set(episode_to_task.values()))
|
||||||
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
|
new_task_df = pd.DataFrame(
|
||||||
|
{"task_index": list(range(len(unique_tasks)))}, index=pd.Index(unique_tasks, name="task")
|
||||||
|
)
|
||||||
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
||||||
|
|
||||||
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
||||||
@@ -1529,7 +1550,7 @@ def modify_tasks(
|
|||||||
|
|
||||||
def convert_image_to_video_dataset(
|
def convert_image_to_video_dataset(
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
output_dir: Path,
|
output_dir: Path | None = None,
|
||||||
repo_id: str | None = None,
|
repo_id: str | None = None,
|
||||||
vcodec: str = "libsvtav1",
|
vcodec: str = "libsvtav1",
|
||||||
pix_fmt: str = "yuv420p",
|
pix_fmt: str = "yuv420p",
|
||||||
@@ -1548,8 +1569,8 @@ def convert_image_to_video_dataset(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: The source LeRobot dataset with images
|
dataset: The source LeRobot dataset with images
|
||||||
output_dir: Directory to save the new video dataset
|
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||||
vcodec: Video codec (default: libsvtav1)
|
vcodec: Video codec (default: libsvtav1)
|
||||||
pix_fmt: Pixel format (default: yuv420p)
|
pix_fmt: Pixel format (default: yuv420p)
|
||||||
g: Group of pictures size (default: 2)
|
g: Group of pictures size (default: 2)
|
||||||
@@ -1600,6 +1621,7 @@ def convert_image_to_video_dataset(
|
|||||||
# Video info will be updated after episodes are encoded
|
# Video info will be updated after episodes are encoded
|
||||||
|
|
||||||
# Create new metadata for video dataset
|
# Create new metadata for video dataset
|
||||||
|
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||||
new_meta = LeRobotDatasetMetadata.create(
|
new_meta = LeRobotDatasetMetadata.create(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
fps=dataset.meta.fps,
|
fps=dataset.meta.fps,
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ class LeRobotDatasetMetadata:
|
|||||||
if self.tasks is None:
|
if self.tasks is None:
|
||||||
new_tasks = tasks
|
new_tasks = tasks
|
||||||
task_indices = range(len(tasks))
|
task_indices = range(len(tasks))
|
||||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
|
self.tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(tasks, name="task"))
|
||||||
else:
|
else:
|
||||||
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||||
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||||
|
|||||||
@@ -341,6 +341,7 @@ def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None:
|
|||||||
|
|
||||||
def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||||
|
tasks.index.name = "task"
|
||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,8 +36,11 @@ Convert a local dataset (works in place):
|
|||||||
```bash
|
```bash
|
||||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||||
--repo-id=lerobot/pusht \
|
--repo-id=lerobot/pusht \
|
||||||
--root=/path/to/local/dataset/directory
|
--root=/path/to/local/dataset/directory \
|
||||||
--push-to-hub=false
|
--push-to-hub=false
|
||||||
|
|
||||||
|
N.B. Path semantics (v2): --root is the exact dataset folder containing
|
||||||
|
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
|
||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -105,7 +108,7 @@ episodes.jsonl
|
|||||||
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
|
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
|
||||||
|
|
||||||
NEW
|
NEW
|
||||||
meta/episodes/chunk-000/episodes_000.parquet
|
meta/episodes/chunk-000/file_000.parquet
|
||||||
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
|
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
|
||||||
-------------------------
|
-------------------------
|
||||||
OLD
|
OLD
|
||||||
@@ -113,15 +116,16 @@ tasks.jsonl
|
|||||||
{"task_index": 1, "task": "Put the blue block in the green bowl"}
|
{"task_index": 1, "task": "Put the blue block in the green bowl"}
|
||||||
|
|
||||||
NEW
|
NEW
|
||||||
meta/tasks/chunk-000/file_000.parquet
|
meta/tasks.parquet
|
||||||
task_index | task
|
task_index | task
|
||||||
-------------------------
|
-------------------------
|
||||||
OLD
|
OLD
|
||||||
episodes_stats.jsonl
|
episodes_stats.jsonl
|
||||||
|
{"episode_index": 1, "stats": {"feature_name": {"min": ..., "max": ..., "mean": ..., "std": ..., "count": ...}}}
|
||||||
|
|
||||||
NEW
|
NEW
|
||||||
meta/episodes_stats/chunk-000/file_000.parquet
|
meta/episodes/chunk-000/file_000.parquet
|
||||||
episode_index | mean | std | min | max
|
episode_index | feature_name/min | feature_name/max | feature_name/mean | feature_name/std | feature_name/count
|
||||||
-------------------------
|
-------------------------
|
||||||
UPDATE
|
UPDATE
|
||||||
meta/info.json
|
meta/info.json
|
||||||
@@ -170,7 +174,7 @@ def convert_tasks(root, new_root):
|
|||||||
tasks, _ = legacy_load_tasks(root)
|
tasks, _ = legacy_load_tasks(root)
|
||||||
task_indices = tasks.keys()
|
task_indices = tasks.keys()
|
||||||
task_strings = tasks.values()
|
task_strings = tasks.values()
|
||||||
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
|
df_tasks = pd.DataFrame({"task_index": task_indices}, index=pd.Index(task_strings, name="task"))
|
||||||
write_tasks(df_tasks, new_root)
|
write_tasks(df_tasks, new_root)
|
||||||
|
|
||||||
|
|
||||||
@@ -201,7 +205,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
|||||||
|
|
||||||
image_keys = get_image_keys(root)
|
image_keys = get_image_keys(root)
|
||||||
|
|
||||||
ep_idx = 0
|
|
||||||
chunk_idx = 0
|
chunk_idx = 0
|
||||||
file_idx = 0
|
file_idx = 0
|
||||||
size_in_mb = 0
|
size_in_mb = 0
|
||||||
@@ -211,9 +214,24 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
|||||||
|
|
||||||
logging.info(f"Converting data files from {len(ep_paths)} episodes")
|
logging.info(f"Converting data files from {len(ep_paths)} episodes")
|
||||||
|
|
||||||
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
|
for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc="convert data files")):
|
||||||
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
|
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
|
||||||
ep_num_frames = get_parquet_num_frames(ep_path)
|
ep_num_frames = get_parquet_num_frames(ep_path)
|
||||||
|
|
||||||
|
# Check if we need to start a new file BEFORE creating metadata
|
||||||
|
if size_in_mb + ep_size_in_mb >= data_file_size_in_mb and len(paths_to_cat) > 0:
|
||||||
|
# Write the accumulated data files
|
||||||
|
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||||
|
|
||||||
|
# Move to next file
|
||||||
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||||
|
|
||||||
|
# Reset for the next file
|
||||||
|
size_in_mb = 0
|
||||||
|
num_frames += ep_num_frames # Still need to accumulate total frames
|
||||||
|
paths_to_cat = []
|
||||||
|
|
||||||
|
# Now create metadata with correct chunk/file indices
|
||||||
ep_metadata = {
|
ep_metadata = {
|
||||||
"episode_index": ep_idx,
|
"episode_index": ep_idx,
|
||||||
"data/chunk_index": chunk_idx,
|
"data/chunk_index": chunk_idx,
|
||||||
@@ -224,20 +242,7 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
|||||||
size_in_mb += ep_size_in_mb
|
size_in_mb += ep_size_in_mb
|
||||||
num_frames += ep_num_frames
|
num_frames += ep_num_frames
|
||||||
episodes_metadata.append(ep_metadata)
|
episodes_metadata.append(ep_metadata)
|
||||||
ep_idx += 1
|
paths_to_cat.append(ep_path)
|
||||||
|
|
||||||
if size_in_mb < data_file_size_in_mb:
|
|
||||||
paths_to_cat.append(ep_path)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if paths_to_cat:
|
|
||||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
|
||||||
|
|
||||||
# Reset for the next file
|
|
||||||
size_in_mb = ep_size_in_mb
|
|
||||||
paths_to_cat = [ep_path]
|
|
||||||
|
|
||||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
|
||||||
|
|
||||||
# Write remaining data if any
|
# Write remaining data if any
|
||||||
if paths_to_cat:
|
if paths_to_cat:
|
||||||
@@ -469,7 +474,7 @@ def convert_dataset(
|
|||||||
|
|
||||||
# Set root based on whether local dataset path is provided
|
# Set root based on whether local dataset path is provided
|
||||||
use_local_dataset = False
|
use_local_dataset = False
|
||||||
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
|
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root)
|
||||||
if root.exists():
|
if root.exists():
|
||||||
validate_local_dataset_version(root)
|
validate_local_dataset_version(root)
|
||||||
use_local_dataset = True
|
use_local_dataset = True
|
||||||
@@ -553,7 +558,7 @@ if __name__ == "__main__":
|
|||||||
"--root",
|
"--root",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Local directory to use for downloading/writing the dataset.",
|
help="Local directory to use for downloading/writing the dataset. Defaults to $HF_LEROBOT_HOME/repo_id.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--push-to-hub",
|
"--push-to-hub",
|
||||||
|
|||||||
@@ -132,10 +132,13 @@ def visualize_dataset(
|
|||||||
|
|
||||||
logging.info("Logging to Rerun")
|
logging.info("Logging to Rerun")
|
||||||
|
|
||||||
|
first_index = None
|
||||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||||
|
if first_index is None:
|
||||||
|
first_index = batch["index"][0].item()
|
||||||
# iterate over the batch
|
# iterate over the batch
|
||||||
for i in range(len(batch["index"])):
|
for i in range(len(batch["index"])):
|
||||||
rr.set_time("frame_index", sequence=batch["frame_index"][i].item())
|
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
|
||||||
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())
|
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())
|
||||||
|
|
||||||
# display each camera image
|
# display each camera image
|
||||||
|
|||||||
@@ -21,6 +21,9 @@ This script allows you to delete episodes, split datasets, merge datasets,
|
|||||||
remove features, modify tasks, and convert image datasets to video format.
|
remove features, modify tasks, and convert image datasets to video format.
|
||||||
When new_repo_id is specified, creates a new dataset.
|
When new_repo_id is specified, creates a new dataset.
|
||||||
|
|
||||||
|
Path semantics (v2): --root and --new_root are exact dataset folders containing
|
||||||
|
meta/, data/, videos/. When omitted, defaults to $HF_LEROBOT_HOME/{repo_id}.
|
||||||
|
|
||||||
Usage Examples:
|
Usage Examples:
|
||||||
|
|
||||||
Delete episodes 0, 2, and 5 from a dataset:
|
Delete episodes 0, 2, and 5 from a dataset:
|
||||||
@@ -29,16 +32,31 @@ Delete episodes 0, 2, and 5 from a dataset:
|
|||||||
--operation.type delete_episodes \
|
--operation.type delete_episodes \
|
||||||
--operation.episode_indices "[0, 2, 5]"
|
--operation.episode_indices "[0, 2, 5]"
|
||||||
|
|
||||||
Delete episodes and save to a new dataset:
|
Delete episodes from a local dataset at a specific path:
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
--repo_id lerobot/pusht \
|
--repo_id lerobot/pusht \
|
||||||
--new_repo_id lerobot/pusht_filtered \
|
--root /path/to/pusht \
|
||||||
--operation.type delete_episodes \
|
--operation.type delete_episodes \
|
||||||
--operation.episode_indices "[0, 2, 5]"
|
--operation.episode_indices "[0, 2, 5]"
|
||||||
|
|
||||||
Split dataset by fractions:
|
Delete episodes and save to a new dataset at a specific path and with a new repo_id:
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
--repo_id lerobot/pusht \
|
--repo_id lerobot/pusht \
|
||||||
|
--new_repo_id lerobot/pusht_filtered \
|
||||||
|
--new_root /path/to/pusht_filtered \
|
||||||
|
--operation.type delete_episodes \
|
||||||
|
--operation.episode_indices "[0, 2, 5]"
|
||||||
|
|
||||||
|
Split dataset by fractions (pusht_train, pusht_val):
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--operation.type split \
|
||||||
|
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||||
|
|
||||||
|
Split dataset by fractions and save split datasets to a specific folder (base_folder/train, base_folder/val):
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--new_root /path/to/base_folder \
|
||||||
--operation.type split \
|
--operation.type split \
|
||||||
--operation.splits '{"train": 0.8, "val": 0.2}'
|
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||||
|
|
||||||
@@ -56,15 +74,29 @@ Split into more than two splits:
|
|||||||
|
|
||||||
Merge multiple datasets:
|
Merge multiple datasets:
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
--repo_id lerobot/pusht_merged \
|
--new_repo_id lerobot/pusht_merged \
|
||||||
--operation.type merge \
|
--operation.type merge \
|
||||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||||
|
|
||||||
|
Merge multiple datasets to a specific output path:
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--new_repo_id lerobot/pusht_merged \
|
||||||
|
--new_root /path/to/pusht_merged \
|
||||||
|
--operation.type merge \
|
||||||
|
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||||
|
|
||||||
|
Merge multiple datasets from a list of local dataset paths:
|
||||||
|
lerobot-edit-dataset \
|
||||||
|
--new_repo_id lerobot/pusht_merged \
|
||||||
|
--operation.type merge \
|
||||||
|
--operation.repo_ids "['pusht_train', 'pusht_val']" \
|
||||||
|
--operation.roots "['/path/to/pusht_train', '/path/to/pusht_val']"
|
||||||
|
|
||||||
Remove camera feature:
|
Remove camera feature:
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
--repo_id lerobot/pusht \
|
--repo_id lerobot/pusht \
|
||||||
--operation.type remove_feature \
|
--operation.type remove_feature \
|
||||||
--operation.feature_names "['observation.images.top']"
|
--operation.feature_names "['observation.image']"
|
||||||
|
|
||||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
@@ -88,8 +120,8 @@ Modify tasks - set default task with overrides for specific episodes (WARNING: m
|
|||||||
Convert image dataset to video format and save locally:
|
Convert image dataset to video format and save locally:
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
--repo_id lerobot/pusht_image \
|
--repo_id lerobot/pusht_image \
|
||||||
--operation.type convert_image_to_video \
|
--new_root /path/to/output/pusht_video \
|
||||||
--operation.output_dir /path/to/output/pusht_video
|
--operation.type convert_image_to_video
|
||||||
|
|
||||||
Convert image dataset to video format and save with new repo_id:
|
Convert image dataset to video format and save with new repo_id:
|
||||||
lerobot-edit-dataset \
|
lerobot-edit-dataset \
|
||||||
@@ -167,6 +199,7 @@ class SplitConfig(OperationConfig):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MergeConfig(OperationConfig):
|
class MergeConfig(OperationConfig):
|
||||||
repo_ids: list[str] | None = None
|
repo_ids: list[str] | None = None
|
||||||
|
roots: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
@OperationConfig.register_subclass("remove_feature")
|
@OperationConfig.register_subclass("remove_feature")
|
||||||
@@ -200,42 +233,46 @@ class ConvertImageToVideoConfig(OperationConfig):
|
|||||||
@OperationConfig.register_subclass("info")
|
@OperationConfig.register_subclass("info")
|
||||||
@dataclass
|
@dataclass
|
||||||
class InfoConfig(OperationConfig):
|
class InfoConfig(OperationConfig):
|
||||||
type: str = "info"
|
|
||||||
show_features: bool = False
|
show_features: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EditDatasetConfig:
|
class EditDatasetConfig:
|
||||||
repo_id: str
|
# Operation configuration.
|
||||||
operation: OperationConfig
|
operation: OperationConfig
|
||||||
# Parent cache directory. Each dataset lives at root/{repo_id}. If None, defaults to HF_LEROBOT_HOME.
|
# Input dataset identifier. Required for every operation except merge, which ignores it.
|
||||||
|
repo_id: str | None = None
|
||||||
|
# Root directory where the input dataset is stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||||
root: str | None = None
|
root: str | None = None
|
||||||
|
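# Edited dataset identifier; defaults to repo_id when omitted. For operations that resolve their output via get_output_path, if the resolved output path (new_root, or $HF_LEROBOT_HOME/new_repo_id) equals the input dataset path, modifications are applied in-place and the original dataset is first moved to a sibling backup directory suffixed with "_old". Required for the merge operation.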
|
||||||
new_repo_id: str | None = None
|
new_repo_id: str | None = None
|
||||||
|
# Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/new_repo_id. For Split operation, this is the base directory for the split datasets.
|
||||||
|
new_root: str | None = None
|
||||||
|
# Upload dataset to Hugging Face hub.
|
||||||
push_to_hub: bool = False
|
push_to_hub: bool = False
|
||||||
|
|
||||||
|
|
||||||
def _resolve_root(root: str | None, repo_id: str) -> Path | None:
|
def get_output_path(
|
||||||
"""Translate a parent cache directory into the exact dataset path expected by LeRobotDataset."""
|
repo_id: str,
|
||||||
return Path(root) / repo_id if root else None
|
new_repo_id: str | None,
|
||||||
|
root: Path | str | None,
|
||||||
|
new_root: Path | str | None,
|
||||||
|
) -> tuple[str, Path]:
|
||||||
|
input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||||
|
|
||||||
|
output_repo_id = new_repo_id if new_repo_id else repo_id
|
||||||
|
output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id
|
||||||
|
|
||||||
def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]:
|
# In case of in-place modification, create a backup of the original dataset (if it exists)
|
||||||
if new_repo_id:
|
if output_path == input_path:
|
||||||
output_repo_id = new_repo_id
|
backup_path = input_path.with_name(input_path.name + "_old")
|
||||||
output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id
|
|
||||||
else:
|
|
||||||
output_repo_id = repo_id
|
|
||||||
dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id
|
|
||||||
old_path = Path(str(dataset_path) + "_old")
|
|
||||||
|
|
||||||
if dataset_path.exists():
|
if input_path.exists():
|
||||||
if old_path.exists():
|
if backup_path.exists():
|
||||||
shutil.rmtree(old_path)
|
shutil.rmtree(backup_path)
|
||||||
shutil.move(str(dataset_path), str(old_path))
|
shutil.move(input_path, backup_path)
|
||||||
|
|
||||||
output_dir = dataset_path
|
return output_repo_id, output_path
|
||||||
|
|
||||||
return output_repo_id, output_dir
|
|
||||||
|
|
||||||
|
|
||||||
def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
|
def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
|
||||||
@@ -245,13 +282,17 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
|
|||||||
if not cfg.operation.episode_indices:
|
if not cfg.operation.episode_indices:
|
||||||
raise ValueError("episode_indices must be specified for delete_episodes operation")
|
raise ValueError("episode_indices must be specified for delete_episodes operation")
|
||||||
|
|
||||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||||
output_repo_id, output_dir = get_output_path(
|
output_repo_id, output_dir = get_output_path(
|
||||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
cfg.repo_id,
|
||||||
|
new_repo_id=cfg.new_repo_id,
|
||||||
|
root=cfg.root,
|
||||||
|
new_root=cfg.new_root,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.new_repo_id is None:
|
# In case of in-place modification, make the dataset point to the backup directory
|
||||||
dataset.root = Path(str(dataset.root) + "_old")
|
if output_dir == dataset.root:
|
||||||
|
dataset.root = dataset.root.with_name(dataset.root.name + "_old")
|
||||||
|
|
||||||
logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}")
|
logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}")
|
||||||
new_dataset = delete_episodes(
|
new_dataset = delete_episodes(
|
||||||
@@ -278,19 +319,27 @@ def handle_split(cfg: EditDatasetConfig) -> None:
|
|||||||
"splits dict must be specified with split names as keys and fractions/episode lists as values"
|
"splits dict must be specified with split names as keys and fractions/episode lists as values"
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
if cfg.new_repo_id is not None:
|
||||||
|
logging.warning(
|
||||||
|
"split uses the original dataset identifier --repo_id to generate split names. The --new_repo_id parameter is ignored."
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||||
|
|
||||||
logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
|
logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
|
||||||
split_datasets = split_dataset(dataset, splits=cfg.operation.splits)
|
split_datasets = split_dataset(
|
||||||
|
dataset,
|
||||||
|
splits=cfg.operation.splits,
|
||||||
|
output_dir=cfg.new_root,
|
||||||
|
)
|
||||||
|
|
||||||
for split_name, split_ds in split_datasets.items():
|
for split_name, split_ds in split_datasets.items():
|
||||||
split_repo_id = f"{cfg.repo_id}_{split_name}"
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames"
|
f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames"
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.push_to_hub:
|
if cfg.push_to_hub:
|
||||||
logging.info(f"Pushing {split_name} split to hub as {split_repo_id}")
|
logging.info(f"Pushing {split_name} split to hub as {split_ds.repo_id}")
|
||||||
LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub()
|
LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
@@ -301,20 +350,29 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
|
|||||||
if not cfg.operation.repo_ids:
|
if not cfg.operation.repo_ids:
|
||||||
raise ValueError("repo_ids must be specified for merge operation")
|
raise ValueError("repo_ids must be specified for merge operation")
|
||||||
|
|
||||||
if not cfg.repo_id:
|
if cfg.repo_id is not None or cfg.root is not None:
|
||||||
raise ValueError("repo_id must be specified as the output repository for merged dataset")
|
logging.warning(
|
||||||
|
"merge uses --new_repo_id and --new_root for the merged dataset. The --repo_id and --root parameters are ignored."
|
||||||
|
)
|
||||||
|
|
||||||
logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
|
if cfg.operation.roots:
|
||||||
datasets = [
|
if len(cfg.operation.roots) != len(cfg.operation.repo_ids):
|
||||||
LeRobotDataset(repo_id, root=_resolve_root(cfg.root, repo_id)) for repo_id in cfg.operation.repo_ids
|
raise ValueError("repo_ids and roots must have the same length for merge operation")
|
||||||
]
|
logging.info(f"Loading {len(cfg.operation.roots)} datasets to merge")
|
||||||
|
datasets = [
|
||||||
|
LeRobotDataset(repo_id=repo_id, root=root)
|
||||||
|
for repo_id, root in zip(cfg.operation.repo_ids, cfg.operation.roots, strict=True)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
|
||||||
|
datasets = [LeRobotDataset(repo_id) for repo_id in cfg.operation.repo_ids]
|
||||||
|
|
||||||
output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id
|
output_dir = Path(cfg.new_root) if cfg.new_root else HF_LEROBOT_HOME / cfg.new_repo_id
|
||||||
|
|
||||||
logging.info(f"Merging datasets into {cfg.repo_id}")
|
logging.info(f"Merging datasets into {cfg.new_repo_id}")
|
||||||
merged_dataset = merge_datasets(
|
merged_dataset = merge_datasets(
|
||||||
datasets,
|
datasets,
|
||||||
output_repo_id=cfg.repo_id,
|
output_repo_id=cfg.new_repo_id,
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -324,7 +382,7 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.push_to_hub:
|
if cfg.push_to_hub:
|
||||||
logging.info(f"Pushing to hub as {cfg.repo_id}")
|
logging.info(f"Pushing to hub as {cfg.new_repo_id}")
|
||||||
LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub()
|
LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
@@ -335,13 +393,17 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
|||||||
if not cfg.operation.feature_names:
|
if not cfg.operation.feature_names:
|
||||||
raise ValueError("feature_names must be specified for remove_feature operation")
|
raise ValueError("feature_names must be specified for remove_feature operation")
|
||||||
|
|
||||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||||
output_repo_id, output_dir = get_output_path(
|
output_repo_id, output_dir = get_output_path(
|
||||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
cfg.repo_id,
|
||||||
|
new_repo_id=cfg.new_repo_id,
|
||||||
|
root=cfg.root,
|
||||||
|
new_root=cfg.new_root,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.new_repo_id is None:
|
# In case of in-place modification, make the dataset point to the backup directory
|
||||||
dataset.root = Path(str(dataset.root) + "_old")
|
if output_dir == dataset.root:
|
||||||
|
dataset.root = dataset.root.with_name(dataset.root.name + "_old")
|
||||||
|
|
||||||
logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}")
|
logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}")
|
||||||
new_dataset = remove_feature(
|
new_dataset = remove_feature(
|
||||||
@@ -369,11 +431,12 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
|||||||
if new_task is None and episode_tasks_raw is None:
|
if new_task is None and episode_tasks_raw is None:
|
||||||
raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")
|
raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")
|
||||||
|
|
||||||
# Warn about in-place modification behavior
|
if cfg.new_repo_id is not None or cfg.new_root is not None:
|
||||||
if cfg.new_repo_id is not None:
|
logging.warning(
|
||||||
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
|
"modify_tasks modifies datasets in-place. The --new_repo_id and --new_root parameters are ignored."
|
||||||
|
)
|
||||||
|
|
||||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||||
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
|
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
|
||||||
|
|
||||||
# Convert episode_tasks keys from string to int if needed (CLI passes strings)
|
# Convert episode_tasks keys from string to int if needed (CLI passes strings)
|
||||||
@@ -404,35 +467,33 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
|||||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||||
# instead of checking isinstance()
|
# instead of checking isinstance()
|
||||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||||
|
|
||||||
# Determine output directory and repo_id
|
# Determine output directory and repo_id
|
||||||
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
|
# Priority: 1) new_root, 2) new_repo_id, 3) operation.output_dir, 4) auto-generated name
|
||||||
output_dir_config = getattr(cfg.operation, "output_dir", None)
|
output_dir_config = getattr(cfg.operation, "output_dir", None)
|
||||||
|
if output_dir_config:
|
||||||
|
logging.warning(
|
||||||
|
"--operation.output_dir is deprecated and will be removed in future versions. "
|
||||||
|
"Please use --new_root instead."
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.new_repo_id:
|
if cfg.new_root:
|
||||||
# Use new_repo_id for both local storage and hub push
|
output_dir = Path(cfg.new_root)
|
||||||
|
output_repo_id = cfg.new_repo_id or f"{cfg.repo_id}_video"
|
||||||
|
logging.info(f"Saving to new_root: {output_dir} as {output_repo_id}")
|
||||||
|
elif cfg.new_repo_id:
|
||||||
output_repo_id = cfg.new_repo_id
|
output_repo_id = cfg.new_repo_id
|
||||||
# Place new dataset as a sibling to the original dataset
|
output_dir = HF_LEROBOT_HOME / cfg.new_repo_id
|
||||||
# Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir)
|
|
||||||
# Extract just the dataset name (after last slash) for the local directory
|
|
||||||
local_dir_name = cfg.new_repo_id.split("/")[-1]
|
|
||||||
output_dir = dataset.root.parent / local_dir_name
|
|
||||||
logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}")
|
logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}")
|
||||||
elif output_dir_config:
|
elif output_dir_config:
|
||||||
# Use custom output directory for local-only storage
|
|
||||||
output_dir = Path(output_dir_config)
|
output_dir = Path(output_dir_config)
|
||||||
# Extract repo name from output_dir for the dataset
|
|
||||||
output_repo_id = output_dir.name
|
output_repo_id = output_dir.name
|
||||||
logging.info(f"Saving to local directory: {output_dir}")
|
logging.info(f"Saving to local directory: {output_dir} as {output_repo_id}")
|
||||||
else:
|
else:
|
||||||
# Auto-generate name: append "_video" to original repo_id
|
|
||||||
output_repo_id = f"{cfg.repo_id}_video"
|
output_repo_id = f"{cfg.repo_id}_video"
|
||||||
# Place new dataset as a sibling to the original dataset
|
output_dir = HF_LEROBOT_HOME / output_repo_id
|
||||||
# Extract just the dataset name (after last slash) for the local directory
|
logging.info(f"Saving to auto-generated location: {output_dir} as {output_repo_id}")
|
||||||
local_dir_name = output_repo_id.split("/")[-1]
|
|
||||||
output_dir = dataset.root.parent / local_dir_name
|
|
||||||
logging.info(f"Saving to auto-generated location: {output_dir}")
|
|
||||||
|
|
||||||
logging.info(f"Converting dataset {cfg.repo_id} to video format")
|
logging.info(f"Converting dataset {cfg.repo_id} to video format")
|
||||||
|
|
||||||
@@ -481,7 +542,7 @@ def handle_info(cfg: EditDatasetConfig):
|
|||||||
if not isinstance(cfg.operation, InfoConfig):
|
if not isinstance(cfg.operation, InfoConfig):
|
||||||
raise ValueError("Operation config must be InfoConfig")
|
raise ValueError("Operation config must be InfoConfig")
|
||||||
|
|
||||||
dataset = LeRobotDataset(cfg.repo_id, root=_resolve_root(cfg.root, cfg.repo_id))
|
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||||
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
|
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
|
||||||
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
|
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
|
||||||
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
|
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
|
||||||
@@ -507,8 +568,20 @@ def handle_info(cfg: EditDatasetConfig):
|
|||||||
sys.stdout.write(f"{feature_dump_str}\n")
|
sys.stdout.write(f"{feature_dump_str}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_config(cfg: EditDatasetConfig) -> None:
|
||||||
|
if isinstance(cfg.operation, MergeConfig):
|
||||||
|
if not cfg.new_repo_id:
|
||||||
|
raise ValueError("--new_repo_id is required for merge operation (the merged dataset identifier)")
|
||||||
|
else:
|
||||||
|
if not cfg.repo_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"--repo_id is required for {cfg.operation.type} operation (the input dataset identifier)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||||
|
_validate_config(cfg)
|
||||||
operation_type = cfg.operation.type
|
operation_type = cfg.operation.type
|
||||||
|
|
||||||
if operation_type == "delete_episodes":
|
if operation_type == "delete_episodes":
|
||||||
|
|||||||
Vendored
+1
-1
@@ -222,7 +222,7 @@ def tasks_factory():
|
|||||||
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
|
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
|
||||||
ids = list(range(total_tasks))
|
ids = list(range(total_tasks))
|
||||||
tasks = [f"Perform action {i}." for i in ids]
|
tasks = [f"Perform action {i}." for i in ids]
|
||||||
df = pd.DataFrame({"task_index": ids}, index=tasks)
|
df = pd.DataFrame({"task_index": ids}, index=pd.Index(tasks, name="task"))
|
||||||
return df
|
return df
|
||||||
|
|
||||||
return _create_tasks
|
return _create_tasks
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
|||||||
OperationConfig,
|
OperationConfig,
|
||||||
RemoveFeatureConfig,
|
RemoveFeatureConfig,
|
||||||
SplitConfig,
|
SplitConfig,
|
||||||
|
_validate_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -51,11 +52,23 @@ class TestOperationTypeParsing:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
|
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
|
||||||
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
|
cfg = parse_cfg(
|
||||||
|
["--repo_id", "test/repo", "--new_repo_id", "test/merged", "--operation.type", type_name]
|
||||||
|
)
|
||||||
assert isinstance(cfg.operation, expected_cls), (
|
assert isinstance(cfg.operation, expected_cls), (
|
||||||
f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}"
|
f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_merge_requires_new_repo_id(self):
|
||||||
|
cfg = parse_cfg(["--operation.type", "merge"])
|
||||||
|
with pytest.raises(ValueError, match="--new_repo_id is required for merge"):
|
||||||
|
_validate_config(cfg)
|
||||||
|
|
||||||
|
def test_non_merge_requires_repo_id(self):
|
||||||
|
cfg = parse_cfg(["--operation.type", "delete_episodes"])
|
||||||
|
with pytest.raises(ValueError, match="--repo_id is required for delete_episodes"):
|
||||||
|
_validate_config(cfg)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"type_name, expected_cls",
|
"type_name, expected_cls",
|
||||||
[
|
[
|
||||||
@@ -69,6 +82,8 @@ class TestOperationTypeParsing:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
|
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
|
||||||
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
|
cfg = parse_cfg(
|
||||||
|
["--repo_id", "test/repo", "--new_repo_id", "test/merged", "--operation.type", type_name]
|
||||||
|
)
|
||||||
resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
|
resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
|
||||||
assert resolved_name == type_name
|
assert resolved_name == type_name
|
||||||
|
|||||||
Reference in New Issue
Block a user