mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b2d3186011 | |||
| 5865170d36 | |||
| 2dd366436e | |||
| 5f15232271 | |||
| bc38261321 | |||
| aaf3707058 | |||
| 89bd58a9a2 | |||
| b22e0315b0 | |||
| fcbf550952 | |||
| af036ce57e | |||
| 1c388c0002 | |||
| 51d3822d75 | |||
| 6600b60e7f | |||
| adebbcf090 | |||
| 3615160d89 |
@@ -1,134 +0,0 @@
|
||||
# Action tokenizer benchmark
|
||||
|
||||
## Questions
|
||||
|
||||
What is the trade-off between:
|
||||
|
||||
- **Compression**: how many tokens are needed to represent an action chunk (e.g. horizon × action_dim floats)?
|
||||
- **Reconstruction quality**: how well does encode-then-decode preserve the original actions?
|
||||
- **Speed**: how long does encoding and decoding take per chunk?
|
||||
|
||||
How to choose an action tokenizer?
|
||||
|
||||
- Which tokenizer architecture (e.g. dct + BPE, DCT + BPE)?
|
||||
- Which **action horizon** and **encoded dimensions** to use?
|
||||
- Which **normalization** (QUANTILES, MEAN_STD, MIN_MAX) and **delta transform** (relative vs absolute actions)?
|
||||
- How do reconstruction error and compression ratio vary across datasets and tokenizer settings?
|
||||
|
||||
This benchmark loads action chunks from a LeRobot dataset using the same pipeline as `lerobot-train-tokenizer`, runs a trained action tokenizer in encode/decode mode, and reports reconstruction error, compression stats, and timing. Results are saved as JSON under `outputs/` for comparison and analysis.
|
||||
|
||||
## Variables
|
||||
|
||||
**Dataset & chunking**
|
||||
|
||||
- **repo_id**: LeRobot dataset (e.g. `lerobot/pusht`). Action statistics and normalization are taken from the dataset metadata when available.
|
||||
- **action_horizon**: Number of future steps per action chunk (must match the tokenizer’s training).
|
||||
- **encoded_dims**: Dimension ranges to encode (e.g. `0:6` or `0:6,7:14`). Must match the tokenizer.
|
||||
- **max_episodes**: Cap on episodes to load (default: all).
|
||||
- **sample_fraction**: Fraction of chunks to sample per episode (default `0.2`) to keep runtime manageable.
|
||||
|
||||
**Transform & normalization**
|
||||
|
||||
- **normalization_mode**: `IDENTITY`, `MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`. Should match the tokenizer’s training.
|
||||
- **delta_dims**: Comma-separated dimension indices for delta (relative) transform.
|
||||
- **use_delta_transform**: Whether to convert actions to relative to current state for those dimensions.
|
||||
- **state_key**: Dataset key for state (e.g. `observation.state`) used when applying delta transform.
|
||||
|
||||
**Tokenizer & evaluation**
|
||||
|
||||
- **action_tokenizer_path**: Path or HuggingFace repo id of the trained tokenizer (e.g. `outputs/wavetoken`).
|
||||
- **max_chunks_for_reconstruction**: Max number of chunks to use for reconstruction and timing (default `500`) to limit runtime.
|
||||
|
||||
### Main parameters
|
||||
|
||||
| parameter | default | description |
|
||||
| -------------------------------- | ---------------------------- | ------------------------------------------------ |
|
||||
| **action_tokenizer_path** | (required) | Path or Hub id of the trained action tokenizer. |
|
||||
| **repo_id** | (required) | LeRobot dataset repo id. |
|
||||
| **action_horizon** | `10` | Future steps per chunk. |
|
||||
| **encoded_dims** | `0:6` | Dimension ranges to encode (e.g. `0:6,7:14`). |
|
||||
| **normalization_mode** | `QUANTILES` | Normalization mode for actions. |
|
||||
| **max_episodes** | all | Max episodes to load. |
|
||||
| **sample_fraction** | `0.2` | Fraction of chunks sampled per episode. |
|
||||
| **max_chunks_for_reconstruction**| `500` | Chunks used for reconstruction and timing. |
|
||||
| **output_dir** | `outputs/action_tokenizer_benchmark` | Directory for results JSON. |
|
||||
|
||||
## Metrics
|
||||
|
||||
**Reconstruction (lower is better)**
|
||||
|
||||
- **reconstruction_mae**: Mean absolute error between original and decoded action chunks.
|
||||
- **reconstruction_mse**: Mean squared error.
|
||||
- **reconstruction_rmse**: Root mean squared error.
|
||||
- **reconstruction_max_abs_error**: Maximum absolute error over all dimensions and samples.
|
||||
- **per_dimension_mae**: MAE per action dimension (list of length `action_dim`).
|
||||
|
||||
**Compression**
|
||||
|
||||
- **compression_ratio**: Ratio (action_horizon × action_dim) / mean number of tokens. Higher means more compression.
|
||||
- **mean_token_length**, **std_token_length**: Mean and standard deviation of token count per chunk.
|
||||
- **min_token_length**, **max_token_length**: Min and max token count.
|
||||
- **p50_token_length**, **p99_token_length**: 50th and 99th percentile token counts.
|
||||
|
||||
**Timing (seconds per chunk)**
|
||||
|
||||
- **mean_encode_time_sec**: Mean time to encode one chunk.
|
||||
- **mean_decode_time_sec**: Mean time to decode one chunk.
|
||||
|
||||
The JSON output also includes **num_chunks_evaluated** and **total_chunks_available** for context.
|
||||
|
||||
## How the benchmark works
|
||||
|
||||
1. **Load dataset**: LeRobot dataset is loaded for the given `repo_id` and `root`.
|
||||
2. **Build action chunks**: For each episode (up to `max_episodes`), action chunks are built with the same logic as `lerobot-train-tokenizer`: sliding window of length `action_horizon`, optional delta transform, and per-episode sampling with `sample_fraction`.
|
||||
3. **Extract and normalize**: Only `encoded_dims` are kept. Normalization is applied using the dataset’s action stats when available, according to `normalization_mode`.
|
||||
4. **Encode / decode**: A random sample of chunks (size `max_chunks_for_reconstruction`) is encoded and then decoded with the tokenizer. Encode and decode times are recorded per chunk.
|
||||
5. **Compute metrics**: Reconstruction metrics are computed between original and decoded chunks; compression and timing stats are aggregated.
|
||||
6. **Save results**: A JSON file is written to `output_dir` with name `{timestamp}_{repo_id}_action_tokenizer_results.json`, containing the full config and all metrics.
|
||||
|
||||
The pipeline (chunking, dimensions, normalization, delta) must match how the tokenizer was trained; otherwise reconstruction error can be large or the tokenizer may raise.
|
||||
|
||||
## Caveats
|
||||
|
||||
- The tokenizer’s **action_horizon** and **action_dim** (and optionally DCT settings) are fixed at training time. The benchmark infers dimensions from the dataset and encoded dims; the tokenizer path must correspond to a model trained with the same horizon and encoded dimensions.
|
||||
- Reconstruction is evaluated in **normalized space** (the same space the tokenizer sees). For interpretation in raw action space, you would need to invert normalization outside this script.
|
||||
- Only one tokenizer and one dataset are evaluated per run. To compare tokenizers or datasets, run the script multiple times and compare the saved JSON files.
|
||||
|
||||
## Example
|
||||
|
||||
Quick run with a local tokenizer and a small number of episodes:
|
||||
|
||||
```bash
|
||||
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
||||
--action-tokenizer-path=outputs/wavetoken \
|
||||
--repo-id=lerobot/pusht \
|
||||
--action-horizon=10 \
|
||||
--max-episodes=50 \
|
||||
--output-dir=outputs/action_tokenizer_benchmark
|
||||
```
|
||||
|
||||
With delta transform and custom encoded dimensions:
|
||||
|
||||
```bash
|
||||
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
||||
--action-tokenizer-path=outputs/wavetoken \
|
||||
--repo-id=lerobot/pusht \
|
||||
--action-horizon=10 \
|
||||
--encoded-dims=0:6,7:14 \
|
||||
--delta-dims=0,1,2,3,4,5 \
|
||||
--use-delta-transform \
|
||||
--normalization-mode=QUANTILES \
|
||||
--max-chunks-for-reconstruction=500 \
|
||||
--output-dir=outputs/action_tokenizer_benchmark
|
||||
```
|
||||
|
||||
Results are written to e.g. `outputs/action_tokenizer_benchmark/2026-02-12_14-30-00_lerobot_pusht_action_tokenizer_results.json`.
|
||||
|
||||
## Results
|
||||
|
||||
Results are stored as JSON in the directory given by `--output-dir` (default: `outputs/action_tokenizer_benchmark`). Each file contains:
|
||||
|
||||
- **config**: All script arguments (tokenizer path, repo_id, action_horizon, encoded_dims, normalization_mode, etc.) for reproducibility.
|
||||
- **metrics**: All reconstruction, compression, and timing metrics described above.
|
||||
|
||||
To compare runs, load and diff or aggregate these JSON files with your own scripts or notebooks.
|
||||
@@ -1,442 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Benchmark action tokenization: reconstruction error, compression ratio, and timing.
|
||||
|
||||
Loads action chunks from a LeRobot dataset, encodes/decodes them with a trained action
|
||||
tokenizer, and reports:
|
||||
- Reconstruction: MAE, MSE, RMSE, max absolute error, per-dimension MAE
|
||||
- Jerk: mean absolute jerk (original and reconstructed), jerk reconstruction MAE
|
||||
- Compression: ratio (input size / mean tokens), token length stats
|
||||
- Timing: mean encode/decode time per chunk
|
||||
|
||||
Results are saved to outputs/action_tokenizer_benchmark/<timestamp>_results.json.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
||||
--action-tokenizer-path=outputs/wavetoken \
|
||||
--repo-id=lerobot/pusht \
|
||||
--action-horizon=10 \
|
||||
--max-episodes=50 \
|
||||
--output-dir=outputs/action_tokenizer_benchmark
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
# Optional: use same helpers as train script if we want to avoid duplication
|
||||
from lerobot.scripts.lerobot_train_tokenizer import (
|
||||
apply_normalization,
|
||||
process_episode,
|
||||
)
|
||||
|
||||
|
||||
def load_action_chunks(
|
||||
repo_id: str,
|
||||
root: str | None,
|
||||
action_horizon: int,
|
||||
max_episodes: int | None,
|
||||
sample_fraction: float,
|
||||
encoded_dims: str,
|
||||
delta_dims: str | None,
|
||||
use_delta_transform: bool,
|
||||
state_key: str,
|
||||
normalization_mode: NormalizationMode,
|
||||
):
|
||||
"""Load and normalize action chunks from a LeRobot dataset (same pipeline as training)."""
|
||||
dataset = LeRobotDataset(repo_id=repo_id, root=root)
|
||||
num_episodes = dataset.num_episodes
|
||||
if max_episodes is not None:
|
||||
num_episodes = min(max_episodes, num_episodes)
|
||||
|
||||
# Parse encoded dims
|
||||
encoded_dim_ranges = []
|
||||
for range_str in encoded_dims.split(","):
|
||||
start, end = map(int, range_str.strip().split(":"))
|
||||
encoded_dim_ranges.append((start, end))
|
||||
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
|
||||
|
||||
delta_dim_list = None
|
||||
if delta_dims is not None and delta_dims.strip():
|
||||
delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")]
|
||||
|
||||
all_chunks = []
|
||||
for ep_idx in range(num_episodes):
|
||||
chunks = process_episode(
|
||||
(
|
||||
dataset,
|
||||
ep_idx,
|
||||
action_horizon,
|
||||
delta_dim_list,
|
||||
sample_fraction,
|
||||
state_key,
|
||||
use_delta_transform,
|
||||
)
|
||||
)
|
||||
if chunks is not None:
|
||||
all_chunks.append(chunks)
|
||||
|
||||
if not all_chunks:
|
||||
raise ValueError("No action chunks collected. Check action_horizon and dataset.")
|
||||
|
||||
all_chunks = np.concatenate(all_chunks, axis=0)
|
||||
|
||||
# Extract encoded dimensions only
|
||||
encoded_chunks = []
|
||||
for start, end in encoded_dim_ranges:
|
||||
encoded_chunks.append(all_chunks[:, :, start:end])
|
||||
encoded_chunks = np.concatenate(encoded_chunks, axis=-1)
|
||||
|
||||
# Normalize
|
||||
norm_stats = dataset.meta.stats
|
||||
if norm_stats is not None and ACTION in norm_stats:
|
||||
action_stats = norm_stats[ACTION]
|
||||
encoded_dim_indices = []
|
||||
for start, end in encoded_dim_ranges:
|
||||
encoded_dim_indices.extend(range(start, end))
|
||||
encoded_dim_indices = np.array(encoded_dim_indices)
|
||||
encoded_stats = {}
|
||||
for stat_name, stat_values in action_stats.items():
|
||||
if isinstance(stat_values, (list, np.ndarray)):
|
||||
stat_array = np.array(stat_values)
|
||||
if len(stat_array) > max(encoded_dim_indices):
|
||||
encoded_stats[stat_name] = stat_array[encoded_dim_indices]
|
||||
if encoded_stats:
|
||||
try:
|
||||
encoded_chunks = apply_normalization(
|
||||
encoded_chunks, encoded_stats, normalization_mode, eps=1e-8
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return encoded_chunks, total_encoded_dims, action_horizon, dataset.repo_id
|
||||
|
||||
|
||||
def compute_reconstruction_metrics(original: np.ndarray, reconstructed: np.ndarray):
|
||||
"""Compute reconstruction error metrics (original and reconstructed same shape [N, T, D])."""
|
||||
diff = reconstructed - original
|
||||
mae = float(np.mean(np.abs(diff)))
|
||||
mse = float(np.mean(diff**2))
|
||||
rmse = float(np.sqrt(mse))
|
||||
max_abs_err = float(np.max(np.abs(diff)))
|
||||
|
||||
# Per-dimension MAE (over N and T)
|
||||
per_dim_mae = np.mean(np.abs(diff), axis=(0, 1))
|
||||
per_dim_mae = per_dim_mae.tolist()
|
||||
|
||||
return {
|
||||
"reconstruction_mae": mae,
|
||||
"reconstruction_mse": mse,
|
||||
"reconstruction_rmse": rmse,
|
||||
"reconstruction_max_abs_error": max_abs_err,
|
||||
"per_dimension_mae": per_dim_mae,
|
||||
}
|
||||
|
||||
|
||||
def compute_jerk_metrics(original: np.ndarray, reconstructed: np.ndarray) -> dict:
|
||||
"""Compute jerk (3rd derivative of action w.r.t. time) metrics.
|
||||
|
||||
Args:
|
||||
original: Action chunks [N, T, D].
|
||||
reconstructed: Reconstructed action chunks [N, T, D].
|
||||
|
||||
Returns:
|
||||
Dict with mean absolute jerk for original, reconstructed, and jerk reconstruction MAE.
|
||||
"""
|
||||
# Jerk = 3rd discrete difference along time axis; need T >= 4
|
||||
if original.shape[1] < 4:
|
||||
return {}
|
||||
jerk_orig = np.diff(original, n=3, axis=1) # (N, T-3, D)
|
||||
jerk_recon = np.diff(reconstructed, n=3, axis=1)
|
||||
mae_jerk_orig = float(np.mean(np.abs(jerk_orig)))
|
||||
mae_jerk_recon = float(np.mean(np.abs(jerk_recon)))
|
||||
jerk_reconstruction_mae = float(np.mean(np.abs(jerk_recon - jerk_orig)))
|
||||
return {
|
||||
"jerk_mae_original": mae_jerk_orig,
|
||||
"jerk_mae_reconstructed": mae_jerk_recon,
|
||||
"jerk_reconstruction_mae": jerk_reconstruction_mae,
|
||||
}
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
action_chunks: np.ndarray,
|
||||
action_horizon: int,
|
||||
action_dim: int,
|
||||
tokenizer_path: str,
|
||||
max_chunks_for_reconstruction: int | None = 500,
|
||||
):
|
||||
"""Encode/decode action chunks and compute metrics."""
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
|
||||
|
||||
n_chunks = len(action_chunks)
|
||||
sample_size = n_chunks
|
||||
if max_chunks_for_reconstruction is not None:
|
||||
sample_size = min(max_chunks_for_reconstruction, n_chunks)
|
||||
rng = np.random.RandomState(42)
|
||||
indices = rng.choice(n_chunks, size=sample_size, replace=False)
|
||||
sample_chunks = action_chunks[indices]
|
||||
|
||||
# Encode
|
||||
token_lengths = []
|
||||
encode_times = []
|
||||
all_tokens = []
|
||||
for i in range(len(sample_chunks)):
|
||||
chunk = sample_chunks[i : i + 1]
|
||||
t0 = time.perf_counter()
|
||||
tokens = processor(chunk)[0]
|
||||
encode_times.append(time.perf_counter() - t0)
|
||||
if isinstance(tokens, list):
|
||||
token_lengths.append(len(tokens))
|
||||
all_tokens.append(tokens)
|
||||
else:
|
||||
n = tokens.shape[0] if hasattr(tokens, "shape") else len(tokens)
|
||||
token_lengths.append(n)
|
||||
all_tokens.append(tokens.tolist() if hasattr(tokens, "tolist") else list(tokens))
|
||||
|
||||
# Decode (processor keeps time_horizon/action_dim from encode)
|
||||
decoded_list = []
|
||||
decode_times = []
|
||||
for i, tok_list in enumerate(all_tokens):
|
||||
t0 = time.perf_counter()
|
||||
recon = processor.decode(
|
||||
[tok_list],
|
||||
time_horizon=action_horizon,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
decode_times.append(time.perf_counter() - t0)
|
||||
decoded_list.append(recon)
|
||||
decoded = np.concatenate(decoded_list, axis=0)
|
||||
|
||||
# Reconstruction metrics
|
||||
metrics = compute_reconstruction_metrics(sample_chunks, decoded)
|
||||
|
||||
# Jerk metrics (3rd derivative along time)
|
||||
jerk_metrics = compute_jerk_metrics(sample_chunks, decoded)
|
||||
metrics.update(jerk_metrics)
|
||||
|
||||
# Compression
|
||||
token_lengths = np.array(token_lengths)
|
||||
input_size = action_horizon * action_dim
|
||||
compression_ratio = input_size / float(np.mean(token_lengths))
|
||||
metrics["compression_ratio"] = compression_ratio
|
||||
metrics["mean_token_length"] = float(np.mean(token_lengths))
|
||||
metrics["std_token_length"] = float(np.std(token_lengths))
|
||||
metrics["min_token_length"] = int(np.min(token_lengths))
|
||||
metrics["max_token_length"] = int(np.max(token_lengths))
|
||||
metrics["p50_token_length"] = float(np.percentile(token_lengths, 50))
|
||||
metrics["p99_token_length"] = float(np.percentile(token_lengths, 99))
|
||||
|
||||
# Timing (seconds per chunk)
|
||||
metrics["mean_encode_time_sec"] = float(np.mean(encode_times))
|
||||
metrics["mean_decode_time_sec"] = float(np.mean(decode_times))
|
||||
metrics["num_chunks_evaluated"] = sample_size
|
||||
metrics["total_chunks_available"] = n_chunks
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def main(
|
||||
action_tokenizer_path: str,
|
||||
repo_id: str,
|
||||
root: str | None = None,
|
||||
action_horizon: int = 10,
|
||||
max_episodes: int | None = 100,
|
||||
sample_fraction: float = 0.2,
|
||||
encoded_dims: str = "0:6",
|
||||
delta_dims: str | None = None,
|
||||
use_delta_transform: bool = False,
|
||||
state_key: str = OBS_STATE,
|
||||
normalization_mode: str = "QUANTILES",
|
||||
max_chunks_for_reconstruction: int | None = 500,
|
||||
output_dir: str | None = None,
|
||||
):
|
||||
if output_dir is None:
|
||||
output_dir = "outputs/action_tokenizer_benchmark"
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
norm_mode = NormalizationMode(normalization_mode)
|
||||
except ValueError:
|
||||
norm_mode = NormalizationMode.QUANTILES
|
||||
|
||||
print("Loading action chunks...")
|
||||
encoded_chunks, action_dim, horizon, _ = load_action_chunks(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
action_horizon=action_horizon,
|
||||
max_episodes=max_episodes,
|
||||
sample_fraction=sample_fraction,
|
||||
encoded_dims=encoded_dims,
|
||||
delta_dims=delta_dims,
|
||||
use_delta_transform=use_delta_transform,
|
||||
state_key=state_key,
|
||||
normalization_mode=norm_mode,
|
||||
)
|
||||
print(f"Loaded {len(encoded_chunks)} chunks, shape {encoded_chunks.shape} (H={horizon}, D={action_dim})")
|
||||
|
||||
print("Running tokenizer benchmark...")
|
||||
metrics = run_benchmark(
|
||||
action_chunks=encoded_chunks,
|
||||
action_horizon=horizon,
|
||||
action_dim=action_dim,
|
||||
tokenizer_path=action_tokenizer_path,
|
||||
max_chunks_for_reconstruction=max_chunks_for_reconstruction,
|
||||
)
|
||||
|
||||
# Attach config for reproducibility
|
||||
results = {
|
||||
"config": {
|
||||
"action_tokenizer_path": action_tokenizer_path,
|
||||
"repo_id": repo_id,
|
||||
"action_horizon": action_horizon,
|
||||
"max_episodes": max_episodes,
|
||||
"sample_fraction": sample_fraction,
|
||||
"encoded_dims": encoded_dims,
|
||||
"delta_dims": delta_dims,
|
||||
"use_delta_transform": use_delta_transform,
|
||||
"state_key": state_key,
|
||||
"normalization_mode": normalization_mode,
|
||||
},
|
||||
"metrics": metrics,
|
||||
}
|
||||
|
||||
timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
safe_repo = repo_id.replace("/", "_")
|
||||
out_file = output_path / f"{timestamp}_{safe_repo}_action_tokenizer_results.json"
|
||||
with open(out_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
print(f"Results saved to {out_file}")
|
||||
print("Metrics:")
|
||||
for k, v in metrics.items():
|
||||
if isinstance(v, list):
|
||||
print(f" {k}: (length {len(v)})")
|
||||
else:
|
||||
print(f" {k}: {v}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark action tokenization (reconstruction error, compression, timing)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action-tokenizer-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or HuggingFace repo id of the trained action tokenizer (e.g. outputs/wavetoken).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="LeRobot dataset repo id (e.g. lerobot/pusht).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Root directory for LeRobot datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action-horizon",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of future steps per action chunk.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-episodes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max episodes to use (default: all).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-fraction",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Fraction of chunks to sample per episode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoded-dims",
|
||||
type=str,
|
||||
default="0:6",
|
||||
help="Dimension ranges to encode (e.g. 0:6,7:14).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delta-dims",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated dimensions for delta transform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-delta-transform",
|
||||
action="store_true",
|
||||
help="Apply delta (relative) transform to specified dimensions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--state-key",
|
||||
type=str,
|
||||
default=OBS_STATE,
|
||||
help="Dataset key for state (for delta transform).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalization-mode",
|
||||
type=str,
|
||||
default="QUANTILES",
|
||||
choices=[m.value for m in NormalizationMode],
|
||||
help="Normalization mode for actions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-chunks-for-reconstruction",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Max chunks to use for reconstruction metrics (default: 500).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="outputs/action_tokenizer_benchmark",
|
||||
help="Directory to save results JSON (default: outputs/action_tokenizer_benchmark).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(
|
||||
action_tokenizer_path=args.action_tokenizer_path,
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
action_horizon=args.action_horizon,
|
||||
max_episodes=args.max_episodes,
|
||||
sample_fraction=args.sample_fraction,
|
||||
encoded_dims=args.encoded_dims,
|
||||
delta_dims=args.delta_dims,
|
||||
use_delta_transform=args.use_delta_transform,
|
||||
state_key=args.state_key,
|
||||
normalization_mode=args.normalization_mode,
|
||||
max_chunks_for_reconstruction=args.max_chunks_for_reconstruction,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
+42
-42
@@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
@@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
lerobot/aloha_mobile_shrimp_image \
|
||||
lerobot/paris_street \
|
||||
lerobot/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
@@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
|
||||
@@ -185,7 +185,7 @@ echo $HF_USER
|
||||
Use the standard recording command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_record.py \
|
||||
lerobot-record \
|
||||
--robot.type=earthrover_mini_plus \
|
||||
--teleop.type=keyboard_rover \
|
||||
--dataset.repo_id=your_username/dataset_name \
|
||||
|
||||
@@ -224,7 +224,7 @@ lerobot-record \
|
||||
--teleop.port=/dev/tty.usbmodem1201 \
|
||||
--teleop.id=right \
|
||||
--teleop.side=right \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--dataset.single_task="Hand recording test with video data" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
@@ -241,7 +241,7 @@ lerobot-replay \
|
||||
--robot.port=/dev/tty.usbmodem58760432281 \
|
||||
--robot.id=right \
|
||||
--robot.side=right \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_camera \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_camera \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
@@ -249,13 +249,13 @@ lerobot-replay \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=nepyope/hand_record_test_with_video_data \
|
||||
--dataset.repo_id=<USER>/hand_record_test_with_video_data \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/hopejr_hand \
|
||||
--job_name=hopejr \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=nepyope/hand_test_policy
|
||||
--policy.repo_id=<USER>/hand_test_policy
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
@@ -270,7 +270,7 @@ lerobot-record \
|
||||
--robot.side=right \
|
||||
--robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=nepyope/eval_hopejr \
|
||||
--dataset.repo_id=<USER>/eval_hopejr \
|
||||
--dataset.single_task="Evaluate hopejr hand policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
|
||||
+1
-1
@@ -60,7 +60,7 @@ policy.type=pi0
|
||||
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=./outputs/pi0_training \
|
||||
|
||||
@@ -56,7 +56,7 @@ policy.type=pi05
|
||||
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py\
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--output_dir=./outputs/pi05_training \
|
||||
|
||||
@@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl
|
||||
Train with **no annotations** - uses linear progress from 0 to 1:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=single_stage \
|
||||
@@ -288,7 +288,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
Train with **dense annotations only** (sparse auto-generated):
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dense_only \
|
||||
@@ -307,7 +307,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
Train with **both sparse and dense annotations**:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=sarm \
|
||||
--policy.annotation_mode=dual \
|
||||
@@ -468,7 +468,7 @@ This script:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
|
||||
@@ -216,7 +216,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
@@ -266,7 +266,7 @@ lerobot-teleoperate \
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
lerobot-record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
|
||||
@@ -12,6 +12,7 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
7. **Show the Info of Datasets** - Show the summary of datasets information such as number of episode etc.
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
@@ -156,6 +157,30 @@ lerobot-edit-dataset \
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Show the information of datasets
|
||||
|
||||
Show the information of datasets such as number of episode, number of frame, File size and so on.
|
||||
No change will be made to the dataset
|
||||
|
||||
```bash
|
||||
|
||||
# Show dataset information without feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
|
||||
# Show dataset information with feature details
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `parameters`: The flag to control show or no show dataset information with feature details.(default=false)
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
@@ -45,7 +45,7 @@ policy.type=wall_x
|
||||
For training WallX, you can use the standard LeRobot training script with the appropriate configuration:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=wall_x \
|
||||
--output_dir=./outputs/wallx_training \
|
||||
|
||||
@@ -154,7 +154,7 @@ lerobot-train \
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||
--dataset.repo_id=<USER>/bimanual-so100-handover-cube \
|
||||
--output_dir=./outputs/xvla_bimanual \
|
||||
--job_name=xvla_so101_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.episode=2
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,726 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Mirror a bimanual dataset in parallel with DataTrove + SLURM, then double it.
|
||||
|
||||
Workflow:
|
||||
1) Split source episodes across `num_shards` ranks and mirror each shard in parallel.
|
||||
2) Aggregate mirrored shards into one mirrored dataset.
|
||||
3) Aggregate [original, mirrored] into a final doubled dataset.
|
||||
|
||||
Example:
|
||||
python examples/port_datasets/slurm_mirror_dataset.py \
|
||||
--repo-id=pepijn/openarm_bimanual \
|
||||
--output-repo-id=pepijn/openarm_bimanual_doubled \
|
||||
--partition=hopper-cpu \
|
||||
--num-shards=256 \
|
||||
--workers=64 \
|
||||
--cpus-per-task=8 \
|
||||
--mem-per-cpu=4G
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OPENARM_MIRRORING_MASK = {
|
||||
"joint_1": -1,
|
||||
"joint_2": -1,
|
||||
"joint_3": -1,
|
||||
"joint_4": 1,
|
||||
"joint_5": -1,
|
||||
"joint_6": -1,
|
||||
"joint_7": -1,
|
||||
"gripper": 1,
|
||||
}
|
||||
|
||||
|
||||
def get_mirroring_mask(robot_type: str | None) -> dict[str, int]:
|
||||
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
|
||||
return OPENARM_MIRRORING_MASK
|
||||
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
|
||||
|
||||
|
||||
def swap_left_right_name(name: str) -> str:
|
||||
value = name.replace("left_", "LEFT_PLACEHOLDER_")
|
||||
value = value.replace("right_", "left_")
|
||||
value = value.replace("LEFT_PLACEHOLDER_", "right_")
|
||||
return value
|
||||
|
||||
|
||||
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
|
||||
mirrored_names = [swap_left_right_name(n) for n in names]
|
||||
old_to_new_idx = {}
|
||||
for old_idx, old_name in enumerate(names):
|
||||
new_name = swap_left_right_name(old_name)
|
||||
new_idx = mirrored_names.index(new_name)
|
||||
old_to_new_idx[old_idx] = new_idx
|
||||
return mirrored_names, old_to_new_idx
|
||||
|
||||
|
||||
def _get_axis_names(feature: dict[str, Any]) -> list[str] | None:
|
||||
names = feature.get("names")
|
||||
if isinstance(names, list):
|
||||
return names
|
||||
if isinstance(names, dict):
|
||||
axes = names.get("axes")
|
||||
if isinstance(axes, list):
|
||||
return axes
|
||||
return None
|
||||
|
||||
|
||||
def _to_numpy(value: Any) -> Any:
|
||||
if isinstance(value, np.ndarray):
|
||||
return value
|
||||
if hasattr(value, "detach"):
|
||||
return value.detach().cpu().numpy()
|
||||
if hasattr(value, "cpu") and hasattr(value, "numpy"):
|
||||
return value.cpu().numpy()
|
||||
if hasattr(value, "numpy"):
|
||||
return value.numpy()
|
||||
return value
|
||||
|
||||
|
||||
def apply_mirroring_mask(value: float, axis_name: str, mirroring_mask: dict[str, int]) -> float:
|
||||
if axis_name.startswith("left_") or axis_name.startswith("right_"):
|
||||
axis_name = axis_name.split("_", 1)[1]
|
||||
joint_name = axis_name.split(".")[0]
|
||||
return value * mirroring_mask.get(joint_name, 1)
|
||||
|
||||
|
||||
def mirror_vector_feature(
|
||||
value: Any,
|
||||
feature: dict[str, Any],
|
||||
mirroring_mask: dict[str, int],
|
||||
) -> Any:
|
||||
array = _to_numpy(value)
|
||||
if not isinstance(array, np.ndarray) or array.ndim != 1:
|
||||
return array
|
||||
|
||||
names = _get_axis_names(feature)
|
||||
if names is None or len(names) != len(array):
|
||||
return array
|
||||
|
||||
mirrored_names, index_mapping = mirror_feature_names(names)
|
||||
mirrored = np.zeros_like(array)
|
||||
for old_idx, new_idx in index_mapping.items():
|
||||
mirrored[new_idx] = apply_mirroring_mask(array[old_idx], mirrored_names[new_idx], mirroring_mask)
|
||||
return mirrored
|
||||
|
||||
|
||||
def flip_horizontal(value: Any, expected_shape: list[int] | tuple[int, ...]) -> Any:
|
||||
array = _to_numpy(value)
|
||||
if not isinstance(array, np.ndarray) or array.ndim != 3:
|
||||
return array
|
||||
|
||||
expected_shape = tuple(expected_shape)
|
||||
if array.shape == expected_shape:
|
||||
return np.flip(array, axis=1).copy() # HWC
|
||||
|
||||
if len(expected_shape) == 3:
|
||||
c, h, w = expected_shape
|
||||
if array.shape == (c, h, w):
|
||||
return np.flip(array, axis=2).copy() # CHW
|
||||
|
||||
# Conservative fallback for unexpected layouts.
|
||||
return np.flip(array, axis=-1).copy()
|
||||
|
||||
|
||||
def build_mirrored_features(features: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
mirrored = {}
|
||||
for key, feature in features.items():
|
||||
new_key = swap_left_right_name(key)
|
||||
new_feature = copy.deepcopy(feature)
|
||||
names = new_feature.get("names")
|
||||
if isinstance(names, list):
|
||||
new_feature["names"] = [swap_left_right_name(name) for name in names]
|
||||
elif isinstance(names, dict) and isinstance(names.get("axes"), list):
|
||||
new_feature["names"]["axes"] = [swap_left_right_name(name) for name in names["axes"]]
|
||||
mirrored[new_key] = new_feature
|
||||
return mirrored
|
||||
|
||||
|
||||
def build_mirrored_frame(
|
||||
item: dict[str, Any],
|
||||
source_features: dict[str, dict[str, Any]],
|
||||
mirroring_mask: dict[str, int],
|
||||
) -> dict[str, Any]:
|
||||
frame = {}
|
||||
for key, feature in source_features.items():
|
||||
if key in DEFAULT_FEATURES:
|
||||
continue
|
||||
|
||||
value = item[key]
|
||||
if key in {"action", "observation.state"}:
|
||||
value = mirror_vector_feature(value, feature, mirroring_mask)
|
||||
elif feature["dtype"] in {"video", "image"}:
|
||||
value = flip_horizontal(value, feature["shape"])
|
||||
else:
|
||||
value = _to_numpy(value)
|
||||
|
||||
frame[swap_left_right_name(key)] = value
|
||||
|
||||
frame["task"] = item["task"]
|
||||
if "timestamp" in item:
|
||||
ts = _to_numpy(item["timestamp"])
|
||||
frame["timestamp"] = float(ts.item() if hasattr(ts, "item") else ts)
|
||||
return frame
|
||||
|
||||
|
||||
def _resolve_source_root(repo_id: str, root: Path | None) -> Path:
|
||||
source_meta = LeRobotDatasetMetadata(repo_id=repo_id, root=root)
|
||||
return source_meta.root
|
||||
|
||||
|
||||
def _get_work_dir(output_repo_id: str, work_dir: Path | None) -> Path:
|
||||
if work_dir is not None:
|
||||
return work_dir
|
||||
safe_name = output_repo_id.replace("/", "__")
|
||||
return HF_LEROBOT_HOME / "_mirror_work" / safe_name
|
||||
|
||||
|
||||
def _get_shard_root(work_dir: Path, world_size: int, rank: int) -> Path:
|
||||
return work_dir / "mirrored_shards" / f"world_{world_size}_rank_{rank}"
|
||||
|
||||
|
||||
def _is_valid_dataset_root(root: Path) -> bool:
|
||||
return (root / "meta" / "info.json").exists()
|
||||
|
||||
|
||||
def mirror_shard(
|
||||
repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
shard_root: Path,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
vcodec: str,
|
||||
overwrite: bool,
|
||||
) -> None:
|
||||
source_dataset = LeRobotDataset(repo_id=repo_id, root=source_root)
|
||||
selected_episodes = list(range(rank, source_dataset.meta.total_episodes, world_size))
|
||||
|
||||
if len(selected_episodes) == 0:
|
||||
logger.info("Rank %s has no episodes assigned. Skipping.", rank)
|
||||
return
|
||||
|
||||
if shard_root.exists():
|
||||
if overwrite:
|
||||
shutil.rmtree(shard_root)
|
||||
elif _is_valid_dataset_root(shard_root):
|
||||
logger.info("Rank %s shard already exists at %s. Skipping.", rank, shard_root)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Shard root {shard_root} exists but is not a valid dataset. Use --overwrite to recreate."
|
||||
)
|
||||
|
||||
mirroring_mask = get_mirroring_mask(source_dataset.meta.robot_type)
|
||||
mirrored_features = build_mirrored_features(source_dataset.meta.features)
|
||||
|
||||
shard_repo_name = f"{mirrored_repo_id}_world_{world_size}_rank_{rank}"
|
||||
mirrored_dataset = LeRobotDataset.create(
|
||||
repo_id=shard_repo_name,
|
||||
root=shard_root,
|
||||
fps=source_dataset.meta.fps,
|
||||
features=mirrored_features,
|
||||
robot_type=source_dataset.meta.robot_type,
|
||||
use_videos=len(source_dataset.meta.video_keys) > 0,
|
||||
vcodec=vcodec,
|
||||
)
|
||||
mirrored_dataset.meta.update_chunk_settings(
|
||||
chunks_size=source_dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=source_dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=source_dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Rank %s processing %s episodes into shard %s",
|
||||
rank,
|
||||
len(selected_episodes),
|
||||
shard_root,
|
||||
)
|
||||
for source_ep_idx in selected_episodes:
|
||||
episode = source_dataset.meta.episodes[source_ep_idx]
|
||||
start_idx = int(episode["dataset_from_index"])
|
||||
end_idx = int(episode["dataset_to_index"])
|
||||
|
||||
for frame_idx in range(start_idx, end_idx):
|
||||
item = source_dataset[frame_idx]
|
||||
mirrored_frame = build_mirrored_frame(
|
||||
item=item,
|
||||
source_features=source_dataset.meta.features,
|
||||
mirroring_mask=mirroring_mask,
|
||||
)
|
||||
mirrored_dataset.add_frame(mirrored_frame)
|
||||
|
||||
mirrored_dataset.save_episode()
|
||||
|
||||
mirrored_dataset.finalize()
|
||||
|
||||
|
||||
class MirrorDatasetShards(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
work_dir: Path,
|
||||
vcodec: str,
|
||||
overwrite: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.source_root = source_root
|
||||
self.mirrored_repo_id = mirrored_repo_id
|
||||
self.work_dir = work_dir
|
||||
self.vcodec = vcodec
|
||||
self.overwrite = overwrite
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
init_logging()
|
||||
shard_root = _get_shard_root(self.work_dir, world_size, rank)
|
||||
mirror_shard(
|
||||
repo_id=self.repo_id,
|
||||
source_root=self.source_root,
|
||||
mirrored_repo_id=self.mirrored_repo_id,
|
||||
shard_root=shard_root,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
vcodec=self.vcodec,
|
||||
overwrite=self.overwrite,
|
||||
)
|
||||
|
||||
|
||||
def make_mirror_executor(
|
||||
repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
work_dir: Path,
|
||||
logs_dir: Path,
|
||||
job_name: str,
|
||||
num_shards: int,
|
||||
workers: int,
|
||||
partition: str,
|
||||
cpus_per_task: int,
|
||||
mem_per_cpu: str,
|
||||
time_limit: str,
|
||||
vcodec: str,
|
||||
overwrite: bool,
|
||||
slurm: bool,
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
MirrorDatasetShards(
|
||||
repo_id=repo_id,
|
||||
source_root=source_root,
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
work_dir=work_dir,
|
||||
vcodec=vcodec,
|
||||
overwrite=overwrite,
|
||||
),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
if partition is None:
|
||||
raise ValueError("`--partition` is required when `--slurm 1`.")
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": num_shards,
|
||||
"workers": workers,
|
||||
"time": time_limit,
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": num_shards, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
class AggregateMirroredShardsStep(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
mirrored_repo_id: str,
|
||||
mirrored_root: Path,
|
||||
work_dir: Path,
|
||||
num_shards: int,
|
||||
overwrite: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.mirrored_repo_id = mirrored_repo_id
|
||||
self.mirrored_root = mirrored_root
|
||||
self.work_dir = work_dir
|
||||
self.num_shards = num_shards
|
||||
self.overwrite = overwrite
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
init_logging()
|
||||
if rank != 0:
|
||||
logger.info("Skipping rank %s for aggregate mirrored step", rank)
|
||||
return
|
||||
aggregate_mirrored_shards(
|
||||
mirrored_repo_id=self.mirrored_repo_id,
|
||||
mirrored_root=self.mirrored_root,
|
||||
work_dir=self.work_dir,
|
||||
num_shards=self.num_shards,
|
||||
overwrite=self.overwrite,
|
||||
)
|
||||
|
||||
|
||||
class BuildDoubledDatasetStep(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
source_repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
mirrored_root: Path,
|
||||
output_repo_id: str,
|
||||
output_root: Path,
|
||||
overwrite: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.source_repo_id = source_repo_id
|
||||
self.source_root = source_root
|
||||
self.mirrored_repo_id = mirrored_repo_id
|
||||
self.mirrored_root = mirrored_root
|
||||
self.output_repo_id = output_repo_id
|
||||
self.output_root = output_root
|
||||
self.overwrite = overwrite
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
init_logging()
|
||||
if rank != 0:
|
||||
logger.info("Skipping rank %s for build doubled step", rank)
|
||||
return
|
||||
build_doubled_dataset(
|
||||
source_repo_id=self.source_repo_id,
|
||||
source_root=self.source_root,
|
||||
mirrored_repo_id=self.mirrored_repo_id,
|
||||
mirrored_root=self.mirrored_root,
|
||||
output_repo_id=self.output_repo_id,
|
||||
output_root=self.output_root,
|
||||
overwrite=self.overwrite,
|
||||
)
|
||||
|
||||
|
||||
class PushDoubledDatasetStep(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
output_repo_id: str,
|
||||
output_root: Path,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_repo_id = output_repo_id
|
||||
self.output_root = output_root
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
init_logging()
|
||||
if rank != 0:
|
||||
logger.info("Skipping rank %s for push step", rank)
|
||||
return
|
||||
logger.info("Pushing doubled dataset to hub: %s", self.output_repo_id)
|
||||
LeRobotDataset(self.output_repo_id, root=self.output_root).push_to_hub()
|
||||
|
||||
|
||||
def make_single_task_executor(
|
||||
step: PipelineStep,
|
||||
logs_dir: Path,
|
||||
job_name: str,
|
||||
partition: str | None,
|
||||
cpus_per_task: int,
|
||||
mem_per_cpu: str,
|
||||
time_limit: str,
|
||||
slurm: bool,
|
||||
depends: SlurmPipelineExecutor | None = None,
|
||||
):
|
||||
kwargs = {"pipeline": [step], "logging_dir": str(logs_dir / job_name)}
|
||||
if slurm:
|
||||
if partition is None:
|
||||
raise ValueError("`--partition` is required when `--slurm 1`.")
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
"time": time_limit,
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
"depends": depends,
|
||||
}
|
||||
)
|
||||
return SlurmPipelineExecutor(**kwargs)
|
||||
|
||||
kwargs.update({"tasks": 1, "workers": 1})
|
||||
return LocalPipelineExecutor(**kwargs)
|
||||
|
||||
|
||||
def aggregate_mirrored_shards(
|
||||
mirrored_repo_id: str,
|
||||
mirrored_root: Path,
|
||||
work_dir: Path,
|
||||
num_shards: int,
|
||||
overwrite: bool,
|
||||
):
|
||||
if mirrored_root.exists():
|
||||
if overwrite:
|
||||
shutil.rmtree(mirrored_root)
|
||||
elif _is_valid_dataset_root(mirrored_root):
|
||||
logger.info("Mirrored dataset already exists at %s. Skipping aggregation.", mirrored_root)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Mirrored root {mirrored_root} exists but is not a valid dataset. Use --overwrite to recreate."
|
||||
)
|
||||
|
||||
shard_repo_ids = []
|
||||
shard_roots = []
|
||||
for rank in range(num_shards):
|
||||
shard_root = _get_shard_root(work_dir, num_shards, rank)
|
||||
if _is_valid_dataset_root(shard_root):
|
||||
shard_repo_ids.append(f"{mirrored_repo_id}_world_{num_shards}_rank_{rank}")
|
||||
shard_roots.append(shard_root)
|
||||
|
||||
if len(shard_repo_ids) == 0:
|
||||
raise RuntimeError("No mirrored shards were produced. Nothing to aggregate.")
|
||||
|
||||
logger.info("Aggregating %s mirrored shards into %s", len(shard_repo_ids), mirrored_root)
|
||||
aggregate_datasets(
|
||||
repo_ids=shard_repo_ids,
|
||||
roots=shard_roots,
|
||||
aggr_repo_id=mirrored_repo_id,
|
||||
aggr_root=mirrored_root,
|
||||
)
|
||||
|
||||
|
||||
def build_doubled_dataset(
|
||||
source_repo_id: str,
|
||||
source_root: Path,
|
||||
mirrored_repo_id: str,
|
||||
mirrored_root: Path,
|
||||
output_repo_id: str,
|
||||
output_root: Path,
|
||||
overwrite: bool,
|
||||
):
|
||||
if output_root.exists():
|
||||
if overwrite:
|
||||
shutil.rmtree(output_root)
|
||||
elif _is_valid_dataset_root(output_root):
|
||||
logger.info("Doubled dataset already exists at %s. Skipping final aggregation.", output_root)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Output root {output_root} exists but is not a valid dataset. Use --overwrite to recreate."
|
||||
)
|
||||
|
||||
logger.info("Aggregating source + mirrored into doubled dataset at %s", output_root)
|
||||
aggregate_datasets(
|
||||
repo_ids=[source_repo_id, mirrored_repo_id],
|
||||
roots=[source_root, mirrored_root],
|
||||
aggr_repo_id=output_repo_id,
|
||||
aggr_root=output_root,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repo id.")
|
||||
parser.add_argument("--output-repo-id", type=str, required=True, help="Final doubled dataset repo id.")
|
||||
parser.add_argument("--root", type=Path, default=None, help="Root path of source dataset.")
|
||||
parser.add_argument(
|
||||
"--output-root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root path where final doubled dataset is written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--work-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Intermediate directory for mirrored shards and mirrored aggregate dataset.",
|
||||
)
|
||||
parser.add_argument("--logs-dir", type=Path, required=True, help="DataTrove logs path.")
|
||||
parser.add_argument("--job-name", type=str, default="mirror_dataset", help="SLURM job name.")
|
||||
parser.add_argument("--num-shards", type=int, default=256, help="Number of DataTrove tasks/ranks.")
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Max concurrent DataTrove workers on SLURM.",
|
||||
)
|
||||
parser.add_argument("--partition", type=str, default=None, help="SLURM partition (e.g. hopper-cpu).")
|
||||
parser.add_argument("--cpus-per-task", type=int, default=8, help="CPU count per SLURM task.")
|
||||
parser.add_argument("--mem-per-cpu", type=str, default="4G", help="Memory per CPU for SLURM task.")
|
||||
parser.add_argument("--time", type=str, default="24:00:00", help="SLURM time limit.")
|
||||
parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec for output videos.")
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Use SLURM executor. Set 0 for local sequential debugging.",
|
||||
)
|
||||
parser.add_argument("--overwrite", action="store_true", help="Delete existing intermediate/final outputs.")
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push final doubled dataset to Hugging Face Hub after completion.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
init_logging()
|
||||
slurm = args.slurm == 1
|
||||
|
||||
source_root = _resolve_source_root(args.repo_id, args.root)
|
||||
output_root = args.output_root if args.output_root is not None else HF_LEROBOT_HOME / args.output_repo_id
|
||||
work_dir = _get_work_dir(args.output_repo_id, args.work_dir)
|
||||
mirrored_repo_id = f"{args.output_repo_id}_mirrored"
|
||||
mirrored_root = work_dir / "mirrored_aggregate"
|
||||
|
||||
work_dir.mkdir(parents=True, exist_ok=True)
|
||||
args.logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
mirror_executor = make_mirror_executor(
|
||||
repo_id=args.repo_id,
|
||||
source_root=source_root,
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
work_dir=work_dir,
|
||||
logs_dir=args.logs_dir,
|
||||
job_name=args.job_name,
|
||||
num_shards=args.num_shards,
|
||||
workers=args.workers,
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time,
|
||||
vcodec=args.vcodec,
|
||||
overwrite=args.overwrite,
|
||||
slurm=slurm,
|
||||
)
|
||||
if slurm:
|
||||
aggregate_executor = make_single_task_executor(
|
||||
step=AggregateMirroredShardsStep(
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
mirrored_root=mirrored_root,
|
||||
work_dir=work_dir,
|
||||
num_shards=args.num_shards,
|
||||
overwrite=args.overwrite,
|
||||
),
|
||||
logs_dir=args.logs_dir,
|
||||
job_name=f"{args.job_name}_aggregate_mirrored",
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time,
|
||||
slurm=True,
|
||||
depends=mirror_executor,
|
||||
)
|
||||
build_executor = make_single_task_executor(
|
||||
step=BuildDoubledDatasetStep(
|
||||
source_repo_id=args.repo_id,
|
||||
source_root=source_root,
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
mirrored_root=mirrored_root,
|
||||
output_repo_id=args.output_repo_id,
|
||||
output_root=output_root,
|
||||
overwrite=args.overwrite,
|
||||
),
|
||||
logs_dir=args.logs_dir,
|
||||
job_name=f"{args.job_name}_build_doubled",
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time,
|
||||
slurm=True,
|
||||
depends=aggregate_executor,
|
||||
)
|
||||
|
||||
final_executor: SlurmPipelineExecutor | LocalPipelineExecutor = build_executor
|
||||
push_executor = None
|
||||
if args.push_to_hub:
|
||||
push_executor = make_single_task_executor(
|
||||
step=PushDoubledDatasetStep(
|
||||
output_repo_id=args.output_repo_id,
|
||||
output_root=output_root,
|
||||
),
|
||||
logs_dir=args.logs_dir,
|
||||
job_name=f"{args.job_name}_push",
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time,
|
||||
slurm=True,
|
||||
depends=build_executor,
|
||||
)
|
||||
final_executor = push_executor
|
||||
|
||||
final_executor.run()
|
||||
logger.info(
|
||||
"Submitted SLURM chain. job_ids: mirror=%s aggregate=%s doubled=%s push=%s",
|
||||
mirror_executor.job_id,
|
||||
aggregate_executor.job_id,
|
||||
build_executor.job_id,
|
||||
push_executor.job_id if push_executor is not None else None,
|
||||
)
|
||||
return
|
||||
|
||||
mirror_executor.run()
|
||||
aggregate_mirrored_shards(
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
mirrored_root=mirrored_root,
|
||||
work_dir=work_dir,
|
||||
num_shards=args.num_shards,
|
||||
overwrite=args.overwrite,
|
||||
)
|
||||
build_doubled_dataset(
|
||||
source_repo_id=args.repo_id,
|
||||
source_root=source_root,
|
||||
mirrored_repo_id=mirrored_repo_id,
|
||||
mirrored_root=mirrored_root,
|
||||
output_repo_id=args.output_repo_id,
|
||||
output_root=output_root,
|
||||
overwrite=args.overwrite,
|
||||
)
|
||||
if args.push_to_hub:
|
||||
logger.info("Pushing doubled dataset to hub: %s", args.output_repo_id)
|
||||
LeRobotDataset(args.output_repo_id, root=output_root).push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -27,8 +27,8 @@ measuring consistency and ground truth alignment.
|
||||
Usage:
|
||||
# Basic usage with smolvla policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
@@ -58,16 +58,16 @@ Usage:
|
||||
--device=cuda
|
||||
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lipsop/reuben_pi0 \
|
||||
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
|
||||
--policy.path=<USER>/reuben_pi0 \
|
||||
--dataset.repo_id=<USER>/so101_cube_in_cup \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--use_torch_compile=true \
|
||||
@@ -75,8 +75,8 @@ Usage:
|
||||
|
||||
# With torch.compile on CUDA (CUDA graphs disabled by default)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
@@ -84,8 +84,8 @@ Usage:
|
||||
|
||||
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=<USER>/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune \
|
||||
|
||||
@@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
@@ -41,7 +41,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
@@ -53,7 +53,7 @@ Usage:
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.path=<USER>/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
|
||||
+4
-4
@@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
dependencies = [
|
||||
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
@@ -76,9 +76,9 @@ dependencies = [
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=1.1.1,<2.0.0",
|
||||
|
||||
@@ -150,7 +150,7 @@ class Camera(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -530,7 +530,7 @@ class OpenCVCamera(Camera):
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -201,7 +201,7 @@ class Reachy2Camera(Camera):
|
||||
return self.read()
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -573,7 +573,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
|
||||
@@ -656,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
|
||||
@@ -122,19 +122,9 @@ def load_nested_dataset(
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
with SuppressProgressBars():
|
||||
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
|
||||
# PyArrow loads the entire dataset into memory
|
||||
if episodes is None:
|
||||
return Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
|
||||
arrow_dataset = pa_ds.dataset(paths, format="parquet")
|
||||
filter_expr = pa_ds.field("episode_index").isin(episodes)
|
||||
table = arrow_dataset.to_table(filter=filter_expr)
|
||||
|
||||
if features is not None:
|
||||
table = table.cast(features.arrow_schema)
|
||||
|
||||
return Dataset(table)
|
||||
# We use .from_parquet() memory-mapped loading for efficiency
|
||||
filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None
|
||||
return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
|
||||
@@ -529,7 +529,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
"(e.g. `lerobot/pusht`, `<USER>/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
|
||||
@@ -27,18 +27,18 @@ Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 5
|
||||
|
||||
@@ -714,12 +714,12 @@ Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path pepijn223/sarm_single_uni4 \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
--num-visualizations 10
|
||||
""",
|
||||
|
||||
@@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
@@ -40,7 +40,7 @@ and an action expert.
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--dataset.repo_id=<USER>/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
@@ -44,6 +44,7 @@ from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
@@ -87,6 +88,7 @@ __all__ = [
|
||||
"DoneProcessorStep",
|
||||
"EnvAction",
|
||||
"EnvTransition",
|
||||
"GymHILAdapterProcessorStep",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"IdentityProcessorStep",
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .converters import to_tensor
|
||||
from .core import EnvAction, EnvTransition, PolicyAction
|
||||
from .hil_processor import TELEOP_ACTION_KEY
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@@ -89,6 +90,13 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
new_transition[TransitionKey.ACTION] = torch_action
|
||||
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if TELEOP_ACTION_KEY in complementary_data:
|
||||
teleop_action = complementary_data[TELEOP_ACTION_KEY]
|
||||
if isinstance(teleop_action, EnvAction):
|
||||
complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -312,6 +312,37 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gym_hil_adapter_processor")
|
||||
class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors.
|
||||
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
transition[TransitionKey.INFO] = info
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
|
||||
@@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]):
|
||||
Args:
|
||||
save_directory: The directory where the pipeline will be saved. If None, saves to
|
||||
HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}.
|
||||
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`.
|
||||
repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`.
|
||||
push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it.
|
||||
card_kwargs: Additional arguments passed to the card template to customize the card.
|
||||
config_filename: The name of the JSON configuration file. If None, a name is
|
||||
|
||||
@@ -36,6 +36,7 @@ from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
GripperPenaltyProcessorStep,
|
||||
GymHILAdapterProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
@@ -379,6 +380,7 @@ def make_processors(
|
||||
]
|
||||
|
||||
env_pipeline_steps = [
|
||||
GymHILAdapterProcessorStep(),
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
@@ -608,7 +610,14 @@ def control_loop(
|
||||
|
||||
dataset = None
|
||||
if cfg.mode == "record":
|
||||
action_features = teleop_device.action_features
|
||||
if teleop_device:
|
||||
action_features = teleop_device.action_features
|
||||
else:
|
||||
action_features = {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": ["delta_x", "delta_y", "delta_z", "gripper"],
|
||||
}
|
||||
features = {
|
||||
ACTION: action_features,
|
||||
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
|
||||
@@ -656,7 +665,7 @@ def control_loop(
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
|
||||
|
||||
# Use the new step function
|
||||
transition = step_env_and_process_transition(
|
||||
@@ -725,6 +734,8 @@ def control_loop(
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Finalizing dataset before pushing to hub")
|
||||
dataset.finalize()
|
||||
logging.info("Pushing dataset to hub")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ class HopeJrArm(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -171,7 +171,7 @@ class HopeJrHand(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -193,7 +193,7 @@ class KochFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -360,7 +360,7 @@ class LeKiwi(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class OmxFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ class OpenArmFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ class Reachy2Robot(Robot):
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
|
||||
return obs_dict
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class SOFollowerConfig:
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
use_degrees: bool = True
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so101_follower")
|
||||
|
||||
@@ -187,7 +187,7 @@ class SOFollower(Robot):
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
|
||||
@@ -324,7 +324,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Cameras - read images from ZMQ cameras
|
||||
for cam_name, cam in self._cameras.items():
|
||||
obs[cam_name] = cam.async_read()
|
||||
obs[cam_name] = cam.read_latest()
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
@@ -47,16 +47,14 @@ local$ rerun lerobot_pusht_episode_0.rrd
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine through streaming:
|
||||
(You need to forward the websocket port to the distant machine, with
|
||||
`ssh -L 9087:localhost:9087 username@remote-host`)
|
||||
```
|
||||
distant$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--mode distant \
|
||||
--ws-port 9087
|
||||
--grpc-port 9876
|
||||
|
||||
local$ rerun ws://localhost:9087
|
||||
local$ rerun rerun+http://IP:GRPC_PORT/proxy
|
||||
```
|
||||
|
||||
"""
|
||||
@@ -75,6 +73,7 @@ import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
@@ -93,10 +92,11 @@ def visualize_dataset(
|
||||
num_workers: int = 0,
|
||||
mode: str = "local",
|
||||
web_port: int = 9090,
|
||||
ws_port: int = 9087,
|
||||
grpc_port: int = 9876,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
**kwargs,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert output_dir is not None, (
|
||||
@@ -126,7 +126,9 @@ def visualize_dataset(
|
||||
gc.collect()
|
||||
|
||||
if mode == "distant":
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port)
|
||||
server_uri = rr.serve_grpc(grpc_port=grpc_port)
|
||||
logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy")
|
||||
rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri)
|
||||
|
||||
logging.info("Logging to Rerun")
|
||||
|
||||
@@ -226,7 +228,7 @@ def main():
|
||||
"Mode of viewing between 'local' or 'distant'. "
|
||||
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
||||
"'distant' creates a server on the distant machine where the data is stored. "
|
||||
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
"Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -238,8 +240,13 @@ def main():
|
||||
parser.add_argument(
|
||||
"--ws-port",
|
||||
type=int,
|
||||
default=9087,
|
||||
help="Web socket port for rerun.io when `--mode distant` is set.",
|
||||
help="deprecated, please use --grpc-port instead.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grpc-port",
|
||||
type=int,
|
||||
default=9876,
|
||||
help="gRPC port for rerun.io when `--mode distant` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
@@ -265,9 +272,7 @@ def main():
|
||||
|
||||
parser.add_argument(
|
||||
"--display-compressed-images",
|
||||
type=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="If set, display compressed images in Rerun instead of uncompressed ones.",
|
||||
)
|
||||
|
||||
@@ -277,6 +282,14 @@ def main():
|
||||
root = kwargs.pop("root")
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
if kwargs["ws_port"] is not None:
|
||||
logging.warning(
|
||||
"--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead."
|
||||
)
|
||||
logging.warning("Setting grpc_port to ws_port value.")
|
||||
kwargs["grpc_port"] = kwargs.pop("ws_port")
|
||||
|
||||
init_logging()
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
|
||||
|
||||
|
||||
@@ -24,96 +24,112 @@ When new_repo_id is specified, creates a new dataset.
|
||||
Usage Examples:
|
||||
|
||||
Delete episodes 0, 2, and 5 from a dataset:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Delete episodes and save to a new dataset:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_filtered \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Split dataset by fractions:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||
|
||||
Split dataset by episode indices:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
|
||||
|
||||
Split into more than two splits:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
|
||||
|
||||
Merge multiple datasets:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_merged \
|
||||
--operation.type merge \
|
||||
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||
|
||||
Remove camera feature:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Pick up the cube and place it"
|
||||
|
||||
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
|
||||
|
||||
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
Convert image dataset to video format and save with new repo_id:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video
|
||||
|
||||
Convert image dataset to video format and push to hub:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Show dataset information:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features true
|
||||
|
||||
Show dataset information without feature details:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type info \
|
||||
--operation.show_features false
|
||||
|
||||
Using JSON config file:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
lerobot-edit-dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
convert_image_to_video_dataset,
|
||||
@@ -129,39 +145,46 @@ from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteEpisodesConfig:
|
||||
type: str = "delete_episodes"
|
||||
class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("delete_episodes")
|
||||
@dataclass
|
||||
class DeleteEpisodesConfig(OperationConfig):
|
||||
episode_indices: list[int] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("split")
|
||||
@dataclass
|
||||
class SplitConfig:
|
||||
type: str = "split"
|
||||
class SplitConfig(OperationConfig):
|
||||
splits: dict[str, float | list[int]] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("merge")
|
||||
@dataclass
|
||||
class MergeConfig:
|
||||
type: str = "merge"
|
||||
class MergeConfig(OperationConfig):
|
||||
repo_ids: list[str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("remove_feature")
|
||||
@dataclass
|
||||
class RemoveFeatureConfig:
|
||||
type: str = "remove_feature"
|
||||
class RemoveFeatureConfig(OperationConfig):
|
||||
feature_names: list[str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("modify_tasks")
|
||||
@dataclass
|
||||
class ModifyTasksConfig:
|
||||
type: str = "modify_tasks"
|
||||
class ModifyTasksConfig(OperationConfig):
|
||||
new_task: str | None = None
|
||||
episode_tasks: dict[str, str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("convert_image_to_video")
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig:
|
||||
type: str = "convert_image_to_video"
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
output_dir: str | None = None
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
@@ -174,17 +197,17 @@ class ConvertImageToVideoConfig:
|
||||
max_frames_per_batch: int | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("info")
|
||||
@dataclass
|
||||
class InfoConfig(OperationConfig):
|
||||
type: str = "info"
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: (
|
||||
DeleteEpisodesConfig
|
||||
| SplitConfig
|
||||
| MergeConfig
|
||||
| RemoveFeatureConfig
|
||||
| ModifyTasksConfig
|
||||
| ConvertImageToVideoConfig
|
||||
)
|
||||
operation: OperationConfig
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
push_to_hub: bool = False
|
||||
@@ -433,6 +456,49 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
logging.info("Dataset saved locally (not pushed to hub)")
|
||||
|
||||
|
||||
def _get_dataset_size(repo_path):
|
||||
import os
|
||||
|
||||
total = 0
|
||||
with os.scandir(repo_path) as it:
|
||||
for entry in it:
|
||||
if entry.is_file():
|
||||
total += entry.stat().st_size
|
||||
elif entry.is_dir():
|
||||
total += _get_dataset_size(entry.path)
|
||||
return total
|
||||
|
||||
|
||||
def handle_info(cfg: EditDatasetConfig):
|
||||
if not isinstance(cfg.operation, InfoConfig):
|
||||
raise ValueError("Operation config must be InfoConfig")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
sys.stdout.write(f"======Info {dataset.meta.repo_id}\n")
|
||||
sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n")
|
||||
sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n")
|
||||
sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n")
|
||||
sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n")
|
||||
sys.stdout.write(
|
||||
f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(
|
||||
f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n"
|
||||
)
|
||||
sys.stdout.write(f"FPS: {dataset.meta.fps}\n")
|
||||
|
||||
total_file_size = _get_dataset_size(dataset.root)
|
||||
sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n")
|
||||
if cfg.operation.show_features:
|
||||
import json
|
||||
|
||||
feature_dump_str = json.dumps(
|
||||
dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ")
|
||||
)
|
||||
sys.stdout.write("Features:\n")
|
||||
sys.stdout.write(f"{feature_dump_str}\n")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
operation_type = cfg.operation.type
|
||||
@@ -449,11 +515,11 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown operation type: {operation_type}\n"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
|
||||
)
|
||||
available = ", ".join(OperationConfig.get_known_choices())
|
||||
raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
@@ -398,7 +398,14 @@ def record_loop(
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
precise_sleep(max(1 / fps - dt_s, 0.0))
|
||||
|
||||
sleep_time_s: float = 1 / fps - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ lerobot-replay \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=aliberts/record-test \
|
||||
--dataset.repo_id=<USER>/record-test \
|
||||
--dataset.episode=0
|
||||
```
|
||||
|
||||
|
||||
@@ -166,9 +166,9 @@ def apply_normalization(
|
||||
if q01 is None or q99 is None:
|
||||
raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats")
|
||||
denom = np.maximum(q99 - q01, eps)
|
||||
# No clipping: match training pipeline NormalizerProcessorStep so tokenizer
|
||||
# is fit on the full range of normalized values (including tails outside [-1, 1]).
|
||||
return 2.0 * (data - q01) / denom - 1.0
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q01, q99)
|
||||
return 2.0 * (clipped - q01) / denom - 1.0
|
||||
|
||||
if mode == NormalizationMode.QUANTILE10:
|
||||
q10 = stats.get("q10")
|
||||
@@ -176,8 +176,9 @@ def apply_normalization(
|
||||
if q10 is None or q90 is None:
|
||||
raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats")
|
||||
denom = np.maximum(q90 - q10, eps)
|
||||
# No clipping: match training pipeline NormalizerProcessorStep.
|
||||
return 2.0 * (data - q10) / denom - 1.0
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q10, q90)
|
||||
return 2.0 * (clipped - q10) / denom - 1.0
|
||||
|
||||
raise ValueError(f"Unsupported normalization mode: {mode}")
|
||||
|
||||
@@ -305,7 +306,7 @@ def train_fast_tokenizer(
|
||||
|
||||
# download the tokenizer source code (not pretrained weights)
|
||||
# we'll train a new tokenizer on our own data
|
||||
base_tokenizer = AutoProcessor.from_pretrained("/fsx/jade_choghari/outputs/libero_tokenizer_wavetoken1", trust_remote_code=True)
|
||||
base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
|
||||
# convert action_chunks array to list of arrays (expected by .fit())
|
||||
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
|
||||
@@ -319,8 +320,6 @@ def train_fast_tokenizer(
|
||||
vocab_size=vocab_size,
|
||||
time_horizon=action_chunks.shape[1], # action_horizon
|
||||
action_dim=action_chunks.shape[2], # encoded dimensions
|
||||
wavelet="dmey",
|
||||
level=1,
|
||||
)
|
||||
print("✓ Tokenizer training complete!")
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class SOLeaderConfig:
|
||||
port: str
|
||||
|
||||
# Whether to use degrees for angles
|
||||
use_degrees: bool = False
|
||||
use_degrees: bool = True
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("so101_leader")
|
||||
|
||||
@@ -16,14 +16,14 @@ import platform
|
||||
import time
|
||||
|
||||
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005):
|
||||
"""
|
||||
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
|
||||
|
||||
Parameters:
|
||||
- seconds: duration to wait
|
||||
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms
|
||||
|
||||
Note:
|
||||
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.optim.schedulers import (
|
||||
@@ -38,6 +40,10 @@ def test_diffuser_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -56,6 +62,10 @@ def test_vqbet_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
@@ -76,6 +86,10 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
|
||||
"last_epoch": 1,
|
||||
"lr_lambdas": [None],
|
||||
}
|
||||
|
||||
if Version(torch.__version__) >= Version("2.8"):
|
||||
expected_state_dict["_is_initial"] = False
|
||||
|
||||
assert scheduler.state_dict() == expected_state_dict
|
||||
|
||||
|
||||
|
||||
@@ -142,6 +142,7 @@ def _make_reachy2_camera_mock(*args, **kwargs):
|
||||
cam.connect = MagicMock()
|
||||
cam.disconnect = MagicMock()
|
||||
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
return cam
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.lerobot_edit_dataset import (
|
||||
ConvertImageToVideoConfig,
|
||||
DeleteEpisodesConfig,
|
||||
EditDatasetConfig,
|
||||
InfoConfig,
|
||||
MergeConfig,
|
||||
ModifyTasksConfig,
|
||||
OperationConfig,
|
||||
RemoveFeatureConfig,
|
||||
SplitConfig,
|
||||
)
|
||||
|
||||
|
||||
def parse_cfg(cli_args: list[str]) -> EditDatasetConfig:
|
||||
"""Helper to parse CLI args into an EditDatasetConfig via draccus."""
|
||||
return draccus.parse(EditDatasetConfig, args=cli_args)
|
||||
|
||||
|
||||
class TestOperationTypeParsing:
|
||||
"""Test that --operation.type correctly selects the right config subclass."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"type_name, expected_cls",
|
||||
[
|
||||
("delete_episodes", DeleteEpisodesConfig),
|
||||
("split", SplitConfig),
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
|
||||
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
|
||||
assert isinstance(cfg.operation, expected_cls), (
|
||||
f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"type_name, expected_cls",
|
||||
[
|
||||
("delete_episodes", DeleteEpisodesConfig),
|
||||
("split", SplitConfig),
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
)
|
||||
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
|
||||
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
|
||||
resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
|
||||
assert resolved_name == type_name
|
||||
Reference in New Issue
Block a user