mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-13 07:39:53 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 32fc5504cc | |||
| fc8a388a25 | |||
| 3c84d271d5 | |||
| 1ba3975020 | |||
| 35363c5798 | |||
| 778db19a17 | |||
| d2d01399d6 | |||
| 5eba4ce6f4 | |||
| cca0296cd6 | |||
| 489cb7b6b9 | |||
| e14bdf57d0 | |||
| 97e7e0f9ed | |||
| 0f39248445 | |||
| a6370dd783 | |||
| 14a15f90e7 | |||
| 9c24a09665 | |||
| b18cef2e26 | |||
| 5c6182176f | |||
| 55c0471db9 | |||
| ec04b7ce3a | |||
| 04cbf669cf | |||
| 3409ef0dc2 | |||
| 4483184875 | |||
| 149628dfd5 | |||
| bf337e716d | |||
| 736b43f3cf |
@@ -101,9 +101,11 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: |
|
||||
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
|
||||
github.event_name == 'push' ||
|
||||
github.event_name == 'workflow_dispatch'
|
||||
github.repository == 'huggingface/lerobot' && (
|
||||
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
|
||||
github.event_name == 'push' ||
|
||||
github.event_name == 'workflow_dispatch'
|
||||
)
|
||||
outputs:
|
||||
image_tag: ${{ steps.set_tag.outputs.image_tag }}
|
||||
env:
|
||||
|
||||
@@ -91,6 +91,7 @@ jobs:
|
||||
name: Build and Push Docker
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
|
||||
env:
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
# Action tokenizer benchmark
|
||||
|
||||
## Questions
|
||||
|
||||
What is the trade-off between:
|
||||
|
||||
- **Compression**: how many tokens are needed to represent an action chunk (e.g. horizon × action_dim floats)?
|
||||
- **Reconstruction quality**: how well does encode-then-decode preserve the original actions?
|
||||
- **Speed**: how long does encoding and decoding take per chunk?
|
||||
|
||||
How to choose an action tokenizer?
|
||||
|
||||
- Which tokenizer architecture (e.g. dct + BPE, DCT + BPE)?
|
||||
- Which **action horizon** and **encoded dimensions** to use?
|
||||
- Which **normalization** (QUANTILES, MEAN_STD, MIN_MAX) and **delta transform** (relative vs absolute actions)?
|
||||
- How do reconstruction error and compression ratio vary across datasets and tokenizer settings?
|
||||
|
||||
This benchmark loads action chunks from a LeRobot dataset using the same pipeline as `lerobot-train-tokenizer`, runs a trained action tokenizer in encode/decode mode, and reports reconstruction error, compression stats, and timing. Results are saved as JSON under `outputs/` for comparison and analysis.
|
||||
|
||||
## Variables
|
||||
|
||||
**Dataset & chunking**
|
||||
|
||||
- **repo_id**: LeRobot dataset (e.g. `lerobot/pusht`). Action statistics and normalization are taken from the dataset metadata when available.
|
||||
- **action_horizon**: Number of future steps per action chunk (must match the tokenizer’s training).
|
||||
- **encoded_dims**: Dimension ranges to encode (e.g. `0:6` or `0:6,7:14`). Must match the tokenizer.
|
||||
- **max_episodes**: Cap on episodes to load (default: all).
|
||||
- **sample_fraction**: Fraction of chunks to sample per episode (default `0.2`) to keep runtime manageable.
|
||||
|
||||
**Transform & normalization**
|
||||
|
||||
- **normalization_mode**: `IDENTITY`, `MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`. Should match the tokenizer’s training.
|
||||
- **delta_dims**: Comma-separated dimension indices for delta (relative) transform.
|
||||
- **use_delta_transform**: Whether to convert actions to relative to current state for those dimensions.
|
||||
- **state_key**: Dataset key for state (e.g. `observation.state`) used when applying delta transform.
|
||||
|
||||
**Tokenizer & evaluation**
|
||||
|
||||
- **action_tokenizer_path**: Path or HuggingFace repo id of the trained tokenizer (e.g. `outputs/wavetoken`).
|
||||
- **max_chunks_for_reconstruction**: Max number of chunks to use for reconstruction and timing (default `500`) to limit runtime.
|
||||
|
||||
### Main parameters
|
||||
|
||||
| parameter | default | description |
|
||||
| -------------------------------- | ---------------------------- | ------------------------------------------------ |
|
||||
| **action_tokenizer_path** | (required) | Path or Hub id of the trained action tokenizer. |
|
||||
| **repo_id** | (required) | LeRobot dataset repo id. |
|
||||
| **action_horizon** | `10` | Future steps per chunk. |
|
||||
| **encoded_dims** | `0:6` | Dimension ranges to encode (e.g. `0:6,7:14`). |
|
||||
| **normalization_mode** | `QUANTILES` | Normalization mode for actions. |
|
||||
| **max_episodes** | all | Max episodes to load. |
|
||||
| **sample_fraction** | `0.2` | Fraction of chunks sampled per episode. |
|
||||
| **max_chunks_for_reconstruction**| `500` | Chunks used for reconstruction and timing. |
|
||||
| **output_dir** | `outputs/action_tokenizer_benchmark` | Directory for results JSON. |
|
||||
|
||||
## Metrics
|
||||
|
||||
**Reconstruction (lower is better)**
|
||||
|
||||
- **reconstruction_mae**: Mean absolute error between original and decoded action chunks.
|
||||
- **reconstruction_mse**: Mean squared error.
|
||||
- **reconstruction_rmse**: Root mean squared error.
|
||||
- **reconstruction_max_abs_error**: Maximum absolute error over all dimensions and samples.
|
||||
- **per_dimension_mae**: MAE per action dimension (list of length `action_dim`).
|
||||
|
||||
**Compression**
|
||||
|
||||
- **compression_ratio**: Ratio (action_horizon × action_dim) / mean number of tokens. Higher means more compression.
|
||||
- **mean_token_length**, **std_token_length**: Mean and standard deviation of token count per chunk.
|
||||
- **min_token_length**, **max_token_length**: Min and max token count.
|
||||
- **p50_token_length**, **p99_token_length**: 50th and 99th percentile token counts.
|
||||
|
||||
**Timing (seconds per chunk)**
|
||||
|
||||
- **mean_encode_time_sec**: Mean time to encode one chunk.
|
||||
- **mean_decode_time_sec**: Mean time to decode one chunk.
|
||||
|
||||
The JSON output also includes **num_chunks_evaluated** and **total_chunks_available** for context.
|
||||
|
||||
## How the benchmark works
|
||||
|
||||
1. **Load dataset**: LeRobot dataset is loaded for the given `repo_id` and `root`.
|
||||
2. **Build action chunks**: For each episode (up to `max_episodes`), action chunks are built with the same logic as `lerobot-train-tokenizer`: sliding window of length `action_horizon`, optional delta transform, and per-episode sampling with `sample_fraction`.
|
||||
3. **Extract and normalize**: Only `encoded_dims` are kept. Normalization is applied using the dataset’s action stats when available, according to `normalization_mode`.
|
||||
4. **Encode / decode**: A random sample of chunks (size `max_chunks_for_reconstruction`) is encoded and then decoded with the tokenizer. Encode and decode times are recorded per chunk.
|
||||
5. **Compute metrics**: Reconstruction metrics are computed between original and decoded chunks; compression and timing stats are aggregated.
|
||||
6. **Save results**: A JSON file is written to `output_dir` with name `{timestamp}_{repo_id}_action_tokenizer_results.json`, containing the full config and all metrics.
|
||||
|
||||
The pipeline (chunking, dimensions, normalization, delta) must match how the tokenizer was trained; otherwise reconstruction error can be large or the tokenizer may raise.
|
||||
|
||||
## Caveats
|
||||
|
||||
- The tokenizer’s **action_horizon** and **action_dim** (and optionally DCT settings) are fixed at training time. The benchmark infers dimensions from the dataset and encoded dims; the tokenizer path must correspond to a model trained with the same horizon and encoded dimensions.
|
||||
- Reconstruction is evaluated in **normalized space** (the same space the tokenizer sees). For interpretation in raw action space, you would need to invert normalization outside this script.
|
||||
- Only one tokenizer and one dataset are evaluated per run. To compare tokenizers or datasets, run the script multiple times and compare the saved JSON files.
|
||||
|
||||
## Example
|
||||
|
||||
Quick run with a local tokenizer and a small number of episodes:
|
||||
|
||||
```bash
|
||||
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
||||
--action-tokenizer-path=outputs/wavetoken \
|
||||
--repo-id=lerobot/pusht \
|
||||
--action-horizon=10 \
|
||||
--max-episodes=50 \
|
||||
--output-dir=outputs/action_tokenizer_benchmark
|
||||
```
|
||||
|
||||
With delta transform and custom encoded dimensions:
|
||||
|
||||
```bash
|
||||
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
||||
--action-tokenizer-path=outputs/wavetoken \
|
||||
--repo-id=lerobot/pusht \
|
||||
--action-horizon=10 \
|
||||
--encoded-dims=0:6,7:14 \
|
||||
--delta-dims=0,1,2,3,4,5 \
|
||||
--use-delta-transform \
|
||||
--normalization-mode=QUANTILES \
|
||||
--max-chunks-for-reconstruction=500 \
|
||||
--output-dir=outputs/action_tokenizer_benchmark
|
||||
```
|
||||
|
||||
Results are written to e.g. `outputs/action_tokenizer_benchmark/2026-02-12_14-30-00_lerobot_pusht_action_tokenizer_results.json`.
|
||||
|
||||
## Results
|
||||
|
||||
Results are stored as JSON in the directory given by `--output-dir` (default: `outputs/action_tokenizer_benchmark`). Each file contains:
|
||||
|
||||
- **config**: All script arguments (tokenizer path, repo_id, action_horizon, encoded_dims, normalization_mode, etc.) for reproducibility.
|
||||
- **metrics**: All reconstruction, compression, and timing metrics described above.
|
||||
|
||||
To compare runs, load and diff or aggregate these JSON files with your own scripts or notebooks.
|
||||
@@ -0,0 +1,442 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Benchmark action tokenization: reconstruction error, compression ratio, and timing.
|
||||
|
||||
Loads action chunks from a LeRobot dataset, encodes/decodes them with a trained action
|
||||
tokenizer, and reports:
|
||||
- Reconstruction: MAE, MSE, RMSE, max absolute error, per-dimension MAE
|
||||
- Jerk: mean absolute jerk (original and reconstructed), jerk reconstruction MAE
|
||||
- Compression: ratio (input size / mean tokens), token length stats
|
||||
- Timing: mean encode/decode time per chunk
|
||||
|
||||
Results are saved to outputs/action_tokenizer_benchmark/<timestamp>_results.json.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
||||
--action-tokenizer-path=outputs/wavetoken \
|
||||
--repo-id=lerobot/pusht \
|
||||
--action-horizon=10 \
|
||||
--max-episodes=50 \
|
||||
--output-dir=outputs/action_tokenizer_benchmark
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
# Optional: use same helpers as train script if we want to avoid duplication
|
||||
from lerobot.scripts.lerobot_train_tokenizer import (
|
||||
apply_normalization,
|
||||
process_episode,
|
||||
)
|
||||
|
||||
|
||||
def load_action_chunks(
|
||||
repo_id: str,
|
||||
root: str | None,
|
||||
action_horizon: int,
|
||||
max_episodes: int | None,
|
||||
sample_fraction: float,
|
||||
encoded_dims: str,
|
||||
delta_dims: str | None,
|
||||
use_delta_transform: bool,
|
||||
state_key: str,
|
||||
normalization_mode: NormalizationMode,
|
||||
):
|
||||
"""Load and normalize action chunks from a LeRobot dataset (same pipeline as training)."""
|
||||
dataset = LeRobotDataset(repo_id=repo_id, root=root)
|
||||
num_episodes = dataset.num_episodes
|
||||
if max_episodes is not None:
|
||||
num_episodes = min(max_episodes, num_episodes)
|
||||
|
||||
# Parse encoded dims
|
||||
encoded_dim_ranges = []
|
||||
for range_str in encoded_dims.split(","):
|
||||
start, end = map(int, range_str.strip().split(":"))
|
||||
encoded_dim_ranges.append((start, end))
|
||||
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
|
||||
|
||||
delta_dim_list = None
|
||||
if delta_dims is not None and delta_dims.strip():
|
||||
delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")]
|
||||
|
||||
all_chunks = []
|
||||
for ep_idx in range(num_episodes):
|
||||
chunks = process_episode(
|
||||
(
|
||||
dataset,
|
||||
ep_idx,
|
||||
action_horizon,
|
||||
delta_dim_list,
|
||||
sample_fraction,
|
||||
state_key,
|
||||
use_delta_transform,
|
||||
)
|
||||
)
|
||||
if chunks is not None:
|
||||
all_chunks.append(chunks)
|
||||
|
||||
if not all_chunks:
|
||||
raise ValueError("No action chunks collected. Check action_horizon and dataset.")
|
||||
|
||||
all_chunks = np.concatenate(all_chunks, axis=0)
|
||||
|
||||
# Extract encoded dimensions only
|
||||
encoded_chunks = []
|
||||
for start, end in encoded_dim_ranges:
|
||||
encoded_chunks.append(all_chunks[:, :, start:end])
|
||||
encoded_chunks = np.concatenate(encoded_chunks, axis=-1)
|
||||
|
||||
# Normalize
|
||||
norm_stats = dataset.meta.stats
|
||||
if norm_stats is not None and ACTION in norm_stats:
|
||||
action_stats = norm_stats[ACTION]
|
||||
encoded_dim_indices = []
|
||||
for start, end in encoded_dim_ranges:
|
||||
encoded_dim_indices.extend(range(start, end))
|
||||
encoded_dim_indices = np.array(encoded_dim_indices)
|
||||
encoded_stats = {}
|
||||
for stat_name, stat_values in action_stats.items():
|
||||
if isinstance(stat_values, (list, np.ndarray)):
|
||||
stat_array = np.array(stat_values)
|
||||
if len(stat_array) > max(encoded_dim_indices):
|
||||
encoded_stats[stat_name] = stat_array[encoded_dim_indices]
|
||||
if encoded_stats:
|
||||
try:
|
||||
encoded_chunks = apply_normalization(
|
||||
encoded_chunks, encoded_stats, normalization_mode, eps=1e-8
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return encoded_chunks, total_encoded_dims, action_horizon, dataset.repo_id
|
||||
|
||||
|
||||
def compute_reconstruction_metrics(original: np.ndarray, reconstructed: np.ndarray):
|
||||
"""Compute reconstruction error metrics (original and reconstructed same shape [N, T, D])."""
|
||||
diff = reconstructed - original
|
||||
mae = float(np.mean(np.abs(diff)))
|
||||
mse = float(np.mean(diff**2))
|
||||
rmse = float(np.sqrt(mse))
|
||||
max_abs_err = float(np.max(np.abs(diff)))
|
||||
|
||||
# Per-dimension MAE (over N and T)
|
||||
per_dim_mae = np.mean(np.abs(diff), axis=(0, 1))
|
||||
per_dim_mae = per_dim_mae.tolist()
|
||||
|
||||
return {
|
||||
"reconstruction_mae": mae,
|
||||
"reconstruction_mse": mse,
|
||||
"reconstruction_rmse": rmse,
|
||||
"reconstruction_max_abs_error": max_abs_err,
|
||||
"per_dimension_mae": per_dim_mae,
|
||||
}
|
||||
|
||||
|
||||
def compute_jerk_metrics(original: np.ndarray, reconstructed: np.ndarray) -> dict:
|
||||
"""Compute jerk (3rd derivative of action w.r.t. time) metrics.
|
||||
|
||||
Args:
|
||||
original: Action chunks [N, T, D].
|
||||
reconstructed: Reconstructed action chunks [N, T, D].
|
||||
|
||||
Returns:
|
||||
Dict with mean absolute jerk for original, reconstructed, and jerk reconstruction MAE.
|
||||
"""
|
||||
# Jerk = 3rd discrete difference along time axis; need T >= 4
|
||||
if original.shape[1] < 4:
|
||||
return {}
|
||||
jerk_orig = np.diff(original, n=3, axis=1) # (N, T-3, D)
|
||||
jerk_recon = np.diff(reconstructed, n=3, axis=1)
|
||||
mae_jerk_orig = float(np.mean(np.abs(jerk_orig)))
|
||||
mae_jerk_recon = float(np.mean(np.abs(jerk_recon)))
|
||||
jerk_reconstruction_mae = float(np.mean(np.abs(jerk_recon - jerk_orig)))
|
||||
return {
|
||||
"jerk_mae_original": mae_jerk_orig,
|
||||
"jerk_mae_reconstructed": mae_jerk_recon,
|
||||
"jerk_reconstruction_mae": jerk_reconstruction_mae,
|
||||
}
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
action_chunks: np.ndarray,
|
||||
action_horizon: int,
|
||||
action_dim: int,
|
||||
tokenizer_path: str,
|
||||
max_chunks_for_reconstruction: int | None = 500,
|
||||
):
|
||||
"""Encode/decode action chunks and compute metrics."""
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
|
||||
|
||||
n_chunks = len(action_chunks)
|
||||
sample_size = n_chunks
|
||||
if max_chunks_for_reconstruction is not None:
|
||||
sample_size = min(max_chunks_for_reconstruction, n_chunks)
|
||||
rng = np.random.RandomState(42)
|
||||
indices = rng.choice(n_chunks, size=sample_size, replace=False)
|
||||
sample_chunks = action_chunks[indices]
|
||||
|
||||
# Encode
|
||||
token_lengths = []
|
||||
encode_times = []
|
||||
all_tokens = []
|
||||
for i in range(len(sample_chunks)):
|
||||
chunk = sample_chunks[i : i + 1]
|
||||
t0 = time.perf_counter()
|
||||
tokens = processor(chunk)[0]
|
||||
encode_times.append(time.perf_counter() - t0)
|
||||
if isinstance(tokens, list):
|
||||
token_lengths.append(len(tokens))
|
||||
all_tokens.append(tokens)
|
||||
else:
|
||||
n = tokens.shape[0] if hasattr(tokens, "shape") else len(tokens)
|
||||
token_lengths.append(n)
|
||||
all_tokens.append(tokens.tolist() if hasattr(tokens, "tolist") else list(tokens))
|
||||
|
||||
# Decode (processor keeps time_horizon/action_dim from encode)
|
||||
decoded_list = []
|
||||
decode_times = []
|
||||
for i, tok_list in enumerate(all_tokens):
|
||||
t0 = time.perf_counter()
|
||||
recon = processor.decode(
|
||||
[tok_list],
|
||||
time_horizon=action_horizon,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
decode_times.append(time.perf_counter() - t0)
|
||||
decoded_list.append(recon)
|
||||
decoded = np.concatenate(decoded_list, axis=0)
|
||||
|
||||
# Reconstruction metrics
|
||||
metrics = compute_reconstruction_metrics(sample_chunks, decoded)
|
||||
|
||||
# Jerk metrics (3rd derivative along time)
|
||||
jerk_metrics = compute_jerk_metrics(sample_chunks, decoded)
|
||||
metrics.update(jerk_metrics)
|
||||
|
||||
# Compression
|
||||
token_lengths = np.array(token_lengths)
|
||||
input_size = action_horizon * action_dim
|
||||
compression_ratio = input_size / float(np.mean(token_lengths))
|
||||
metrics["compression_ratio"] = compression_ratio
|
||||
metrics["mean_token_length"] = float(np.mean(token_lengths))
|
||||
metrics["std_token_length"] = float(np.std(token_lengths))
|
||||
metrics["min_token_length"] = int(np.min(token_lengths))
|
||||
metrics["max_token_length"] = int(np.max(token_lengths))
|
||||
metrics["p50_token_length"] = float(np.percentile(token_lengths, 50))
|
||||
metrics["p99_token_length"] = float(np.percentile(token_lengths, 99))
|
||||
|
||||
# Timing (seconds per chunk)
|
||||
metrics["mean_encode_time_sec"] = float(np.mean(encode_times))
|
||||
metrics["mean_decode_time_sec"] = float(np.mean(decode_times))
|
||||
metrics["num_chunks_evaluated"] = sample_size
|
||||
metrics["total_chunks_available"] = n_chunks
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def main(
|
||||
action_tokenizer_path: str,
|
||||
repo_id: str,
|
||||
root: str | None = None,
|
||||
action_horizon: int = 10,
|
||||
max_episodes: int | None = 100,
|
||||
sample_fraction: float = 0.2,
|
||||
encoded_dims: str = "0:6",
|
||||
delta_dims: str | None = None,
|
||||
use_delta_transform: bool = False,
|
||||
state_key: str = OBS_STATE,
|
||||
normalization_mode: str = "QUANTILES",
|
||||
max_chunks_for_reconstruction: int | None = 500,
|
||||
output_dir: str | None = None,
|
||||
):
|
||||
if output_dir is None:
|
||||
output_dir = "outputs/action_tokenizer_benchmark"
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
norm_mode = NormalizationMode(normalization_mode)
|
||||
except ValueError:
|
||||
norm_mode = NormalizationMode.QUANTILES
|
||||
|
||||
print("Loading action chunks...")
|
||||
encoded_chunks, action_dim, horizon, _ = load_action_chunks(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
action_horizon=action_horizon,
|
||||
max_episodes=max_episodes,
|
||||
sample_fraction=sample_fraction,
|
||||
encoded_dims=encoded_dims,
|
||||
delta_dims=delta_dims,
|
||||
use_delta_transform=use_delta_transform,
|
||||
state_key=state_key,
|
||||
normalization_mode=norm_mode,
|
||||
)
|
||||
print(f"Loaded {len(encoded_chunks)} chunks, shape {encoded_chunks.shape} (H={horizon}, D={action_dim})")
|
||||
|
||||
print("Running tokenizer benchmark...")
|
||||
metrics = run_benchmark(
|
||||
action_chunks=encoded_chunks,
|
||||
action_horizon=horizon,
|
||||
action_dim=action_dim,
|
||||
tokenizer_path=action_tokenizer_path,
|
||||
max_chunks_for_reconstruction=max_chunks_for_reconstruction,
|
||||
)
|
||||
|
||||
# Attach config for reproducibility
|
||||
results = {
|
||||
"config": {
|
||||
"action_tokenizer_path": action_tokenizer_path,
|
||||
"repo_id": repo_id,
|
||||
"action_horizon": action_horizon,
|
||||
"max_episodes": max_episodes,
|
||||
"sample_fraction": sample_fraction,
|
||||
"encoded_dims": encoded_dims,
|
||||
"delta_dims": delta_dims,
|
||||
"use_delta_transform": use_delta_transform,
|
||||
"state_key": state_key,
|
||||
"normalization_mode": normalization_mode,
|
||||
},
|
||||
"metrics": metrics,
|
||||
}
|
||||
|
||||
timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
safe_repo = repo_id.replace("/", "_")
|
||||
out_file = output_path / f"{timestamp}_{safe_repo}_action_tokenizer_results.json"
|
||||
with open(out_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
print(f"Results saved to {out_file}")
|
||||
print("Metrics:")
|
||||
for k, v in metrics.items():
|
||||
if isinstance(v, list):
|
||||
print(f" {k}: (length {len(v)})")
|
||||
else:
|
||||
print(f" {k}: {v}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark action tokenization (reconstruction error, compression, timing)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action-tokenizer-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or HuggingFace repo id of the trained action tokenizer (e.g. outputs/wavetoken).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="LeRobot dataset repo id (e.g. lerobot/pusht).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Root directory for LeRobot datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action-horizon",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of future steps per action chunk.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-episodes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max episodes to use (default: all).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-fraction",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Fraction of chunks to sample per episode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoded-dims",
|
||||
type=str,
|
||||
default="0:6",
|
||||
help="Dimension ranges to encode (e.g. 0:6,7:14).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delta-dims",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated dimensions for delta transform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-delta-transform",
|
||||
action="store_true",
|
||||
help="Apply delta (relative) transform to specified dimensions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--state-key",
|
||||
type=str,
|
||||
default=OBS_STATE,
|
||||
help="Dataset key for state (for delta transform).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalization-mode",
|
||||
type=str,
|
||||
default="QUANTILES",
|
||||
choices=[m.value for m in NormalizationMode],
|
||||
help="Normalization mode for actions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-chunks-for-reconstruction",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Max chunks to use for reconstruction metrics (default: 500).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="outputs/action_tokenizer_benchmark",
|
||||
help="Directory to save results JSON (default: outputs/action_tokenizer_benchmark).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(
|
||||
action_tokenizer_path=args.action_tokenizer_path,
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
action_horizon=args.action_horizon,
|
||||
max_episodes=args.max_episodes,
|
||||
sample_fraction=args.sample_fraction,
|
||||
encoded_dims=args.encoded_dims,
|
||||
delta_dims=args.delta_dims,
|
||||
use_delta_transform=args.use_delta_transform,
|
||||
state_key=args.state_key,
|
||||
normalization_mode=args.normalization_mode,
|
||||
max_chunks_for_reconstruction=args.max_chunks_for_reconstruction,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
@@ -7,8 +7,6 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
- local: bring_your_own_policies
|
||||
title: Bring Your Own Policies
|
||||
- local: integrate_hardware
|
||||
@@ -29,6 +27,8 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
@@ -101,11 +101,17 @@
|
||||
title: Earth Rover Mini
|
||||
- local: omx
|
||||
title: OMX
|
||||
- local: openarm
|
||||
title: OpenArm
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
title: Phone
|
||||
title: "Teleoperators"
|
||||
- sections:
|
||||
- local: cameras
|
||||
title: Cameras
|
||||
title: "Sensors"
|
||||
- sections:
|
||||
- local: torch_accelerators
|
||||
title: PyTorch accelerators
|
||||
|
||||
+95
-81
@@ -1,12 +1,22 @@
|
||||
# Cameras
|
||||
|
||||
LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
LeRobot offers multiple options for video capture:
|
||||
|
||||
### Finding your camera
|
||||
| Class | Supported Cameras |
|
||||
| ----------------- | ----------------------------------- |
|
||||
| `OpenCVCamera` | Phone, built-in laptop, USB webcams |
|
||||
| `ZMQCamera` | Network-connected cameras |
|
||||
| `RealSenseCamera` | Intel RealSense (with depth) |
|
||||
| `Reachy2Camera` | Reachy 2 robot cameras |
|
||||
|
||||
To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system.
|
||||
> [!TIP]
|
||||
> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
|
||||
To find the camera indices of the cameras plugged into your system, run the following script:
|
||||
### Find your camera
|
||||
|
||||
Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices.
|
||||
|
||||
`OpenCVCamera` and `RealSenseCamera` support auto-discovery. Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system.
|
||||
|
||||
```bash
|
||||
lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
@@ -14,7 +24,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
|
||||
```
|
||||
```bash
|
||||
--- Detected Cameras ---
|
||||
Camera #0:
|
||||
Name: OpenCV Camera @ 0
|
||||
@@ -33,13 +43,37 @@ Camera #0:
|
||||
> [!WARNING]
|
||||
> When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable.
|
||||
|
||||
## Use Cameras
|
||||
`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings.
|
||||
|
||||
Below are two examples, demonstrating how to work with the API.
|
||||
## Use cameras
|
||||
|
||||
- **Asynchronous frame capture** using an OpenCV-based camera
|
||||
### Frame access modes
|
||||
|
||||
All camera classes implement three access modes for capturing frames:
|
||||
|
||||
| Method | Behavior | Blocks? | Best For |
|
||||
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- |
|
||||
| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture |
|
||||
| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. | With a timeout | Control loops synchronized to camera FPS |
|
||||
| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring |
|
||||
|
||||
### Usage examples
|
||||
|
||||
The following examples show how to use the camera API to configure and capture frames from different camera types.
|
||||
|
||||
- **Blocking and non-blocking frame capture** using an OpenCV-based camera
|
||||
- **Color and depth capture** using an Intel RealSense camera
|
||||
|
||||
> [!WARNING]
|
||||
> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup:
|
||||
>
|
||||
> ```python
|
||||
> with OpenCVCamera(config) as camera:
|
||||
> ...
|
||||
> ```
|
||||
>
|
||||
> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter.
|
||||
|
||||
<hfoptions id="shell_restart">
|
||||
<hfoption id="Open CV Camera">
|
||||
|
||||
@@ -60,16 +94,30 @@ config = OpenCVCameraConfig(
|
||||
)
|
||||
|
||||
# Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default).
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
with OpenCVCamera(config) as camera:
|
||||
|
||||
# Read a frame synchronously — blocks until hardware delivers a new frame
|
||||
frame = camera.read()
|
||||
print(f"read() call returned frame with shape:", frame.shape)
|
||||
|
||||
# Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one
|
||||
try:
|
||||
for i in range(10):
|
||||
frame = camera.async_read(timeout_ms=200)
|
||||
print(f"async_read call returned frame {i} with shape:", frame.shape)
|
||||
except TimeoutError as e:
|
||||
print(f"No frame received within timeout: {e}")
|
||||
|
||||
# Instantly return a frame - returns the most recent frame captured by the camera
|
||||
try:
|
||||
initial_frame = camera.read_latest(max_age_ms=1000)
|
||||
for i in range(10):
|
||||
frame = camera.read_latest(max_age_ms=1000)
|
||||
print(f"read_latest call returned frame {i} with shape:", frame.shape)
|
||||
print(f"Was a new frame received by the camera? {not (initial_frame == frame).any()}")
|
||||
except TimeoutError as e:
|
||||
print(f"Frame too old: {e}")
|
||||
|
||||
# Read frames asynchronously in a loop via `async_read(timeout_ms)`
|
||||
try:
|
||||
for i in range(10):
|
||||
frame = camera.async_read(timeout_ms=200)
|
||||
print(f"Async frame {i} shape:", frame.shape)
|
||||
finally:
|
||||
camera.disconnect()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
@@ -111,10 +159,10 @@ finally:
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Use your phone
|
||||
## Use your phone's camera
|
||||
|
||||
<hfoptions id="use phone">
|
||||
<hfoption id="Mac">
|
||||
<hfoption id="iPhone & macOS">
|
||||
|
||||
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
|
||||
@@ -124,83 +172,49 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
|
||||
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
|
||||
|
||||
Your iPhone should be detected automatically when running the camera setup script in the next section.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
<hfoption id="OBS virtual camera">
|
||||
|
||||
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
|
||||
If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera.
|
||||
|
||||
1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
|
||||
1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
sudo apt install v4l2loopback-dkms v4l-utils
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
|
||||
3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
|
||||
2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android.
|
||||
3. _Download and install [OBS Studio](https://obsproject.com)_.
|
||||
4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_.
|
||||
5. _Start OBS Studio_.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
5. _Start OBS Studio_. Launch with:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
flatpak run com.obsproject.Studio
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
|
||||
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
|
||||
6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks.
|
||||
7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it.
|
||||
8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
|
||||
9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices:
|
||||
9. _Verify the virtual camera setup and resolution_.
|
||||
- **Linux**: Use `v4l2-ctl` to list devices and check resolution:
|
||||
```bash
|
||||
v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path
|
||||
v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path
|
||||
```
|
||||
You should see `VirtualCam` listed and resolution `640x480`.
|
||||
- **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input.
|
||||
- **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
v4l2-ctl --list-devices
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
<details>
|
||||
<summary><strong>Troubleshooting</strong></summary>
|
||||
|
||||
You should see an entry like:
|
||||
> The virtual camera resolution is incorrect.
|
||||
|
||||
```
|
||||
VirtualCam (platform:v4l2loopback-000):
|
||||
/dev/video1
|
||||
```
|
||||
Delete the virtual camera source and recreate it. The resolution cannot be changed after creation.
|
||||
|
||||
10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
|
||||
> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080.
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
v4l2-ctl -d /dev/video1 --get-fmt-video
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`.
|
||||
|
||||
You should see an entry like:
|
||||
|
||||
```
|
||||
>>> Format Video Capture:
|
||||
>>> Width/Height : 640/480
|
||||
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
|
||||
```
|
||||
|
||||
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
|
||||
|
||||
If everything is set up correctly, you can proceed with the rest of the tutorial.
|
||||
</details>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`.
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor.tokenizer_processor import TokenizerProcessor
|
||||
from lerobot.processor.pipeline import ProcessorPipeline
|
||||
|
||||
# Create a tokenizer processor
|
||||
tokenizer_processor = TokenizerProcessor(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -1,13 +1,15 @@
|
||||
# Installation
|
||||
|
||||
## Install [`miniforge`](https://conda-forge.org/download/)
|
||||
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
||||
|
||||
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
```
|
||||
|
||||
## Environment Setup
|
||||
## Step 2: Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
|
||||
@@ -38,7 +40,7 @@ conda install ffmpeg -c conda-forge
|
||||
>
|
||||
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
## Install LeRobot 🤗
|
||||
## Step 3: Install LeRobot 🤗
|
||||
|
||||
### From Source
|
||||
|
||||
|
||||
@@ -0,0 +1,276 @@
|
||||
# OpenArm
|
||||
|
||||
[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment.
|
||||
|
||||
To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev).
|
||||
|
||||
## What's Unique?
|
||||
|
||||
- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation.
|
||||
|
||||
- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks.
|
||||
|
||||
- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use.
|
||||
|
||||
- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided.
|
||||
|
||||
- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs.
|
||||
|
||||
## Platform Requirements
|
||||
|
||||
<Tip warning={true}>
|
||||
**Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter
|
||||
does not have macOS drivers and has not been tested on Windows.
|
||||
</Tip>
|
||||
|
||||
## Safety Guide
|
||||
|
||||
Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). Key points:
|
||||
|
||||
- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps
|
||||
- **Safe distance**: Keep body parts and objects outside the range of motion during operation
|
||||
- **Protective equipment**: Always wear safety goggles; use additional PPE as needed
|
||||
- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm)
|
||||
- **Emergency stop**: Know the location and operation of the emergency stop device
|
||||
- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage
|
||||
|
||||
## Hardware Setup
|
||||
|
||||
Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for:
|
||||
|
||||
- Bill of materials and sourcing
|
||||
- 3D printing instructions
|
||||
- Mechanical assembly
|
||||
- Electrical wiring
|
||||
|
||||
The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm).
|
||||
|
||||
## CAN Bus Setup
|
||||
|
||||
OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface.
|
||||
|
||||
Quick setup:
|
||||
|
||||
```bash
|
||||
# Setup CAN interfaces
|
||||
lerobot-setup-can --mode=setup --interfaces=can0,can1
|
||||
|
||||
# Test motor communication
|
||||
lerobot-setup-can --mode=test --interfaces=can0,can1
|
||||
```
|
||||
|
||||
## Install LeRobot 🤗
|
||||
|
||||
Follow our [Installation Guide](./installation), then install the Damiao motor support:
|
||||
|
||||
```bash
|
||||
pip install -e ".[damiao]"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Follower Arm (Robot)
|
||||
|
||||
<hfoptions id="follower">
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_openarm_follower
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
|
||||
config = OpenArmFollowerConfig(
|
||||
port="can0",
|
||||
side="right", # or "left" for left arm
|
||||
id="my_openarm_follower",
|
||||
)
|
||||
|
||||
follower = OpenArmFollower(config)
|
||||
follower.connect()
|
||||
|
||||
# Read current state
|
||||
obs = follower.get_observation()
|
||||
print(obs)
|
||||
|
||||
# Send action (position in degrees)
|
||||
action = {
|
||||
"joint_1.pos": 0.0,
|
||||
"joint_2.pos": 0.0,
|
||||
"joint_3.pos": 0.0,
|
||||
"joint_4.pos": 45.0,
|
||||
"joint_5.pos": 0.0,
|
||||
"joint_6.pos": 0.0,
|
||||
"joint_7.pos": 0.0,
|
||||
"gripper.pos": 0.0,
|
||||
}
|
||||
follower.send_action(action)
|
||||
|
||||
follower.disconnect()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Leader Arm (Teleoperator)
|
||||
|
||||
The leader arm is used for teleoperation - manually moving it to control the follower arm.
|
||||
|
||||
<hfoptions id="leader">
|
||||
<hfoption id="Command">
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_openarm_leader
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
```python
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
||||
|
||||
config = OpenArmLeaderConfig(
|
||||
port="can1",
|
||||
id="my_openarm_leader",
|
||||
manual_control=True, # Disable torque for manual movement
|
||||
)
|
||||
|
||||
leader = OpenArmLeader(config)
|
||||
leader.connect()
|
||||
|
||||
# Read current position (as action to send to follower)
|
||||
action = leader.get_action()
|
||||
print(action)
|
||||
|
||||
leader.disconnect()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Teleoperation
|
||||
|
||||
To teleoperate OpenArm with leader-follower control:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_follower \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_leader
|
||||
```
|
||||
|
||||
### Bimanual Teleoperation
|
||||
|
||||
To teleoperate a bimanual OpenArm setup with two leader and two follower arms:
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can0 \
|
||||
--robot.left_arm_config.side=left \
|
||||
--robot.right_arm_config.port=can1 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.id=my_bimanual_follower \
|
||||
--teleop.type=bi_openarm_leader \
|
||||
--teleop.left_arm_config.port=can2 \
|
||||
--teleop.right_arm_config.port=can3 \
|
||||
--teleop.id=my_bimanual_leader
|
||||
```
|
||||
|
||||
### Recording Data
|
||||
|
||||
To record a dataset during teleoperation:
|
||||
|
||||
```bash
|
||||
lerobot-record \
|
||||
--robot.type=openarm_follower \
|
||||
--robot.port=can0 \
|
||||
--robot.side=right \
|
||||
--robot.id=my_follower \
|
||||
--teleop.type=openarm_leader \
|
||||
--teleop.port=can1 \
|
||||
--teleop.id=my_leader \
|
||||
--repo-id=my_hf_username/my_openarm_dataset \
|
||||
--fps=30 \
|
||||
--num-episodes=10
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Follower Configuration
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| --------------------- | --------- | ---------------------------------------------------------- |
|
||||
| `port` | - | CAN interface (e.g., `can0`) |
|
||||
| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
| `max_relative_target` | `None` | Safety limit for relative target positions |
|
||||
| `position_kp` | Per-joint | Position control proportional gains |
|
||||
| `position_kd` | Per-joint | Position control derivative gains |
|
||||
|
||||
### Leader Configuration
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------ | --------- | ----------------------------------- |
|
||||
| `port` | - | CAN interface (e.g., `can1`) |
|
||||
| `manual_control` | `True` | Disable torque for manual movement |
|
||||
| `use_can_fd` | `True` | Enable CAN FD for higher data rates |
|
||||
| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) |
|
||||
| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) |
|
||||
|
||||
## Motor Configuration
|
||||
|
||||
OpenArm uses Damiao motors with the following default configuration:
|
||||
|
||||
| Joint | Motor Type | Send ID | Recv ID |
|
||||
| --------------------------- | ---------- | ------- | ------- |
|
||||
| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 |
|
||||
| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 |
|
||||
| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 |
|
||||
| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 |
|
||||
| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 |
|
||||
| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 |
|
||||
| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 |
|
||||
| gripper | DM4310 | 0x08 | 0x18 |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No Response from Motors
|
||||
|
||||
1. Check power supply connections
|
||||
2. Verify CAN wiring (CAN-H, CAN-L, GND)
|
||||
3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0`
|
||||
4. See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details
|
||||
|
||||
### CAN Interface Not Found
|
||||
|
||||
Ensure the CAN interface is configured:
|
||||
|
||||
```bash
|
||||
ip link show can0
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- [OpenArm Website](https://openarm.dev)
|
||||
- [OpenArm Documentation](https://docs.openarm.dev)
|
||||
- [OpenArm GitHub](https://github.com/enactic/openarm)
|
||||
- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide)
|
||||
- [Damiao Motors and CAN Bus](./damiao)
|
||||
@@ -188,7 +188,105 @@ Press `Ctrl+C` to stop the policy.
|
||||
|
||||
## Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
|
||||
You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI.
|
||||
|
||||
### Calibrate Exoskeleton Teleoperator
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo
|
||||
```
|
||||
|
||||
### Teleoperate in Simulation
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset in Simulation
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=true \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
||||
|
||||
---
|
||||
|
||||
## Running on Real Robot
|
||||
|
||||
Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot.
|
||||
|
||||
### Start the Camera Server
|
||||
|
||||
On the robot, start the ZMQ image server:
|
||||
|
||||
```bash
|
||||
python src/lerobot/cameras/zmq/image_server.py
|
||||
```
|
||||
|
||||
Keep this running in a separate terminal for camera streaming during recording.
|
||||
|
||||
### Teleoperate Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-teleoperate \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--fps=100
|
||||
```
|
||||
|
||||
### Record Dataset on Real Robot
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.lerobot_record \
|
||||
--robot.type=unitree_g1 \
|
||||
--robot.is_simulation=false \
|
||||
--robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \
|
||||
--teleop.type=unitree_g1 \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM1 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.id=exo \
|
||||
--dataset.repo_id=your-username/dataset-name \
|
||||
--dataset.single_task="Test" \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
||||
|
||||
Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real)
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
|
||||
@@ -81,24 +81,25 @@ def replay(cfg: ReplayConfig):
|
||||
actions = dataset.hf_dataset.select_columns(ACTION)
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
try:
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
key = f"{name.removeprefix('main_')}.pos"
|
||||
action[key] = action_array[i].item()
|
||||
|
||||
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
|
||||
action["elbow_flex.pos"] -= 90
|
||||
robot.send_action(action)
|
||||
action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90)
|
||||
action["elbow_flex.pos"] -= 90
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+45
-43
@@ -78,40 +78,24 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -120,24 +104,42 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+45
-44
@@ -74,40 +74,23 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
try:
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
@@ -115,26 +98,44 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+17
-15
@@ -42,25 +42,27 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -142,38 +142,24 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -182,24 +168,41 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -149,38 +149,23 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
try:
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
@@ -188,25 +173,43 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -73,32 +73,34 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
finally:
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -142,38 +142,24 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -182,24 +168,41 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -146,38 +146,23 @@ def main():
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
try:
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
@@ -185,25 +170,44 @@ def main():
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -74,32 +74,35 @@ def main():
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i])
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
finally:
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+9
-4
@@ -105,12 +105,17 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Robots
|
||||
openarms = ["lerobot[damiao]"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0,<2.0.0"
|
||||
"onnxruntime>=1.16.0,<2.0.0",
|
||||
"pin>=3.0.0,<4.0.0",
|
||||
"meshcat>=0.3.0,<0.4.0",
|
||||
"matplotlib>=3.9.0,<4.0.0",
|
||||
"casadi>=3.6.0,<4.0.0",
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
@@ -355,9 +360,9 @@ ignore_errors = false
|
||||
module = "lerobot.cameras.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.motors.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.robots.*"
|
||||
|
||||
@@ -13,5 +13,5 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
|
||||
from .utils import make_cameras_from_configs
|
||||
|
||||
@@ -15,11 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
|
||||
|
||||
from .configs import CameraConfig, ColorMode
|
||||
from .configs import CameraConfig
|
||||
|
||||
|
||||
class Camera(abc.ABC):
|
||||
@@ -30,20 +31,12 @@ class Camera(abc.ABC):
|
||||
|
||||
Manages basic camera properties (FPS, resolution) and core operations:
|
||||
- Connection/disconnection
|
||||
- Frame capture (sync/async)
|
||||
- Frame capture (sync/async/latest)
|
||||
|
||||
Attributes:
|
||||
fps (int | None): Configured frames per second
|
||||
width (int | None): Frame width in pixels
|
||||
height (int | None): Frame height in pixels
|
||||
|
||||
Example:
|
||||
class MyCamera(Camera):
|
||||
def __init__(self, config): ...
|
||||
@property
|
||||
def is_connected(self) -> bool: ...
|
||||
def connect(self, warmup=True): ...
|
||||
# Plus other required methods
|
||||
"""
|
||||
|
||||
def __init__(self, config: CameraConfig):
|
||||
@@ -56,6 +49,32 @@ class Camera(abc.ABC):
|
||||
self.width: int | None = config.width
|
||||
self.height: int | None = config.height
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Context manager entry.
|
||||
Automatically connects to the camera.
|
||||
"""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
"""
|
||||
Context manager exit.
|
||||
Automatically disconnects, ensuring resources are released even on error.
|
||||
"""
|
||||
self.disconnect()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
Destructor safety net.
|
||||
Attempts to disconnect if the object is garbage collected without cleanup.
|
||||
"""
|
||||
try:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_connected(self) -> bool:
|
||||
@@ -89,12 +108,10 @@ class Camera(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera.
|
||||
def read(self) -> NDArray[Any]:
|
||||
"""Capture and return a single frame from the camera synchronously.
|
||||
|
||||
Args:
|
||||
color_mode: Desired color mode for the output frame. If None,
|
||||
uses the camera's default color mode.
|
||||
This is a blocking call that will wait for the hardware and its SDK.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured frame as a numpy array.
|
||||
@@ -103,17 +120,64 @@ class Camera(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
|
||||
"""Asynchronously capture and return a single frame from the camera.
|
||||
"""Return the most recent new frame.
|
||||
|
||||
This method retrieves the latest frame captured by the background thread.
|
||||
If a new frame is already available in the buffer (captured since the last call),
|
||||
it returns it immediately.
|
||||
|
||||
It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame
|
||||
was already consumed by a previous `async_read` call.
|
||||
|
||||
Essentially, this method return the latest unconsumed frame, waiting if necessary
|
||||
for a new one to arrive within the specified timeout.
|
||||
|
||||
Usage:
|
||||
- Ideal for control loops where you want to ensure every processed frame
|
||||
is fresh, effectively synchronizing your loop to the camera's FPS.
|
||||
- Causes of a timeout usually include: very low camera FPS, heavy processing load,
|
||||
or if the camera is disconnected.
|
||||
|
||||
Args:
|
||||
timeout_ms: Maximum time to wait for a frame in milliseconds.
|
||||
Defaults to implementation-specific timeout.
|
||||
timeout_ms: Maximum time to wait for a new frame in milliseconds.
|
||||
Defaults to 200ms (0.2s).
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured frame as a numpy array.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If no new frame arrives within `timeout_ms`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Usage:
|
||||
Ideal for scenarios requiring zero latency or decoupled frequencies & when
|
||||
we want a guaranteed frame, such as UI visualization, logging, or
|
||||
non-critical monitoring.
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
NotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
warnings.warn(
|
||||
f"{self.__class__.__name__}.read_latest() is not implemented. "
|
||||
"Please override read_latest(); it will be required in future releases.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.async_read()
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the camera and release resources."""
|
||||
|
||||
@@ -25,6 +25,10 @@ class ColorMode(str, Enum):
|
||||
RGB = "rgb"
|
||||
BGR = "bgr"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> None:
|
||||
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
|
||||
|
||||
|
||||
class Cv2Rotation(int, Enum):
|
||||
NO_ROTATION = 0
|
||||
@@ -32,6 +36,25 @@ class Cv2Rotation(int, Enum):
|
||||
ROTATE_180 = 180
|
||||
ROTATE_270 = -90
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> None:
|
||||
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
|
||||
|
||||
|
||||
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
|
||||
class Cv2Backends(int, Enum):
|
||||
ANY = 0
|
||||
V4L2 = 200
|
||||
DSHOW = 700
|
||||
PVAPI = 800
|
||||
ANDROID = 1000
|
||||
AVFOUNDATION = 1200
|
||||
MSMF = 1400
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> None:
|
||||
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
|
||||
|
||||
@@ -32,10 +32,11 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..utils import get_cv2_backend, get_cv2_rotation
|
||||
from ..utils import get_cv2_rotation
|
||||
from .configuration_opencv import ColorMode, OpenCVCameraConfig
|
||||
|
||||
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
|
||||
@@ -70,34 +71,24 @@ class OpenCVCamera(Camera):
|
||||
Example:
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCamera
|
||||
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation
|
||||
from lerobot.cameras.configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
# Basic usage with camera index 0
|
||||
config = OpenCVCameraConfig(index_or_path=0)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
print(color_image.shape)
|
||||
|
||||
# Read 1 frame asynchronously
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
# When done, properly disconnect the camera using
|
||||
camera.disconnect()
|
||||
|
||||
# Example with custom settings
|
||||
custom_config = OpenCVCameraConfig(
|
||||
index_or_path='/dev/video0', # Or use an index
|
||||
fps=30,
|
||||
width=1280,
|
||||
height=720,
|
||||
color_mode=ColorMode.RGB,
|
||||
rotation=Cv2Rotation.ROTATE_90
|
||||
)
|
||||
custom_camera = OpenCVCamera(custom_config)
|
||||
# ... connect, read, disconnect ...
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -123,10 +114,11 @@ class OpenCVCamera(Camera):
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
self.backend: int = get_cv2_backend()
|
||||
self.backend: int = config.backend
|
||||
|
||||
if self.height and self.width:
|
||||
self.capture_width, self.capture_height = self.width, self.height
|
||||
@@ -141,20 +133,23 @@ class OpenCVCamera(Camera):
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
|
||||
Initializes the OpenCV VideoCapture object, sets desired camera properties
|
||||
(FPS, width, height), and performs initial checks.
|
||||
(FPS, width, height), starts the background reading thread and performs initial checks.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits at connect() time until at least one valid frame
|
||||
has been captured by the background thread. Defaults to True.
|
||||
|
||||
Raises:
|
||||
DeviceAlreadyConnectedError: If the camera is already connected.
|
||||
ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings.
|
||||
ConnectionError: If the specified camera index/path is not found or fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
# Use 1 thread for OpenCV operations to avoid potential conflicts or
|
||||
# blocking in multi-threaded applications, especially during data collection.
|
||||
@@ -170,15 +165,20 @@ class OpenCVCamera(Camera):
|
||||
)
|
||||
|
||||
self._configure_capture_settings()
|
||||
self._start_read_thread()
|
||||
|
||||
if warmup:
|
||||
if warmup and self.warmup_s > 0:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.read()
|
||||
self.async_read(timeout_ms=self.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
with self.frame_lock:
|
||||
if self.latest_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@check_if_not_connected
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""
|
||||
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
|
||||
@@ -196,11 +196,8 @@ class OpenCVCamera(Camera):
|
||||
Raises:
|
||||
RuntimeError: If the camera fails to set any of the specified properties
|
||||
to the requested value.
|
||||
DeviceNotConnectedError: If the camera is not connected when attempting
|
||||
to configure settings.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
@@ -339,6 +336,18 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return found_cameras_info
|
||||
|
||||
def _read_from_hardware(self) -> NDArray[Any]:
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
|
||||
if not ret:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -346,11 +355,6 @@ class OpenCVCamera(Camera):
|
||||
This is a blocking call. It waits for the next available frame from the
|
||||
camera hardware via OpenCV.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
(height, width, channels), using the specified or default
|
||||
@@ -362,34 +366,31 @@ class OpenCVCamera(Camera):
|
||||
received frame dimensions don't match expectations before rotation.
|
||||
ValueError: If an invalid `color_mode` is requested.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.videocapture is None:
|
||||
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
ret, frame = self.videocapture.read()
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
processed_frame = self._postprocess_image(frame, color_mode)
|
||||
self.new_frame_event.clear()
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return processed_frame
|
||||
return frame
|
||||
|
||||
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw frame.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The raw image frame (expected BGR format from OpenCV).
|
||||
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
|
||||
uses the instance's default `self.color_mode`.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The processed image frame.
|
||||
@@ -399,11 +400,10 @@ class OpenCVCamera(Camera):
|
||||
RuntimeError: If the raw frame dimensions do not match the configured
|
||||
`width` and `height`.
|
||||
"""
|
||||
requested_color_mode = self.color_mode if color_mode is None else color_mode
|
||||
|
||||
if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
|
||||
h, w, c = image.shape
|
||||
@@ -417,7 +417,7 @@ class OpenCVCamera(Camera):
|
||||
raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")
|
||||
|
||||
processed_image = image
|
||||
if requested_color_mode == ColorMode.RGB:
|
||||
if self.color_mode == ColorMode.RGB:
|
||||
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
|
||||
@@ -431,7 +431,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -439,30 +439,37 @@ class OpenCVCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.latest_frame = processed_frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
self._stop_read_thread()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
@@ -475,6 +482,12 @@ class OpenCVCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
@@ -482,6 +495,7 @@ class OpenCVCamera(Camera):
|
||||
This method retrieves the most recent frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
It is “best effort” under high FPS.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
@@ -496,17 +510,14 @@ class OpenCVCamera(Camera):
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
f"Read thread alive: {self.thread.is_alive()}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
@@ -518,6 +529,41 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera and cleans up resources.
|
||||
@@ -538,4 +584,9 @@ class OpenCVCamera(Camera):
|
||||
self.videocapture.release()
|
||||
self.videocapture = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
|
||||
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@@ -50,6 +50,7 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
|
||||
warmup_s: Time reading frames before returning from connect (in seconds)
|
||||
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
|
||||
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
@@ -62,22 +63,12 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
fourcc: str | None = None
|
||||
backend: Cv2Backends = Cv2Backends.ANY
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
if self.rotation not in (
|
||||
Cv2Rotation.NO_ROTATION,
|
||||
Cv2Rotation.ROTATE_90,
|
||||
Cv2Rotation.ROTATE_180,
|
||||
Cv2Rotation.ROTATE_270,
|
||||
):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
self.rotation = Cv2Rotation(self.rotation)
|
||||
self.backend = Cv2Backends(self.backend)
|
||||
|
||||
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
|
||||
raise ValueError(
|
||||
|
||||
@@ -74,7 +74,4 @@ class Reachy2CameraConfig(CameraConfig):
|
||||
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
|
||||
)
|
||||
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
|
||||
@@ -32,6 +32,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
|
||||
if TYPE_CHECKING or _reachy2_sdk_available:
|
||||
@@ -80,6 +81,8 @@ class Reachy2Camera(Camera):
|
||||
self.config = config
|
||||
|
||||
self.color_mode = config.color_mode
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
|
||||
self.cam_manager: CameraManager | None = None
|
||||
|
||||
@@ -121,16 +124,12 @@ class Reachy2Camera(Camera):
|
||||
"""
|
||||
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This is a blocking call.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
This method retrieves the most recent frame available in Reachy 2's low-level software.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
@@ -139,12 +138,14 @@ class Reachy2Camera(Camera):
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
@@ -165,25 +166,27 @@ class Reachy2Camera(Camera):
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
if frame is None:
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
raise RuntimeError(f"Internal error: No frame available for {self}.")
|
||||
|
||||
if self.config.color_mode == "rgb":
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
if self.color_mode == ColorMode.RGB:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
self.latest_frame = frame
|
||||
self.latest_timestamp = time.perf_counter()
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame.
|
||||
|
||||
This method retrieves the most recent frame available in Reachy 2's low-level software.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms (0.2 seconds).
|
||||
Same as read()
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame as a NumPy array in the format
|
||||
@@ -194,16 +197,40 @@ class Reachy2Camera(Camera):
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
frame = self.read()
|
||||
return self.read()
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: No frame available for {self}.")
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
return frame
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
tuple[NDArray, float]:
|
||||
- The frame image (numpy array).
|
||||
- The timestamp (time.perf_counter) when this frame was captured.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
|
||||
if self.latest_frame is None or self.latest_timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return self.latest_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
@@ -211,8 +238,6 @@ class Reachy2Camera(Camera):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is already disconnected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
if self.cam_manager is not None:
|
||||
self.cam_manager.disconnect()
|
||||
|
||||
@@ -30,7 +30,8 @@ try:
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
@@ -72,15 +73,14 @@ class RealSenseCamera(Camera):
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
print(color_image.shape)
|
||||
|
||||
# Read 1 frame asynchronously
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# When done, properly disconnect the camera using
|
||||
camera.disconnect()
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
# Example with depth capture and custom settings
|
||||
custom_config = RealSenseCameraConfig(
|
||||
@@ -133,7 +133,9 @@ class RealSenseCamera(Camera):
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_color_frame: NDArray[Any] | None = None
|
||||
self.latest_depth_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
@@ -151,6 +153,7 @@ class RealSenseCamera(Camera):
|
||||
"""Checks if the camera pipeline is started and streams are active."""
|
||||
return self.rs_pipeline is not None and self.rs_profile is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the RealSense camera specified in the configuration.
|
||||
@@ -158,14 +161,16 @@ class RealSenseCamera(Camera):
|
||||
Initializes the RealSense pipeline, configures the required streams (color
|
||||
and optionally depth), starts the pipeline, and validates the actual stream settings.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits at connect() time until at least one valid frame
|
||||
has been captured by the background thread. Defaults to True.
|
||||
|
||||
Raises:
|
||||
DeviceAlreadyConnectedError: If the camera is already connected.
|
||||
ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
|
||||
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
|
||||
RuntimeError: If the pipeline starts but fails to apply requested settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
self.rs_pipeline = rs.pipeline()
|
||||
rs_config = rs.config()
|
||||
@@ -181,15 +186,18 @@ class RealSenseCamera(Camera):
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
self._start_read_thread()
|
||||
|
||||
if warmup:
|
||||
time.sleep(
|
||||
1
|
||||
) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.read()
|
||||
time.sleep(0.1)
|
||||
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
self.warmup_s = max(self.warmup_s, 1)
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.warmup_s:
|
||||
self.async_read(timeout_ms=self.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
with self.frame_lock:
|
||||
if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -282,6 +290,7 @@ class RealSenseCamera(Camera):
|
||||
if self.use_depth:
|
||||
rs_config.enable_stream(rs.stream.depth)
|
||||
|
||||
@check_if_not_connected
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""Sets fps, width, and height from device stream if not already configured.
|
||||
|
||||
@@ -291,8 +300,6 @@ class RealSenseCamera(Camera):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If device is not connected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
|
||||
|
||||
if self.rs_profile is None:
|
||||
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
|
||||
@@ -312,6 +319,7 @@ class RealSenseCamera(Camera):
|
||||
self.width, self.height = actual_width, actual_height
|
||||
self.capture_width, self.capture_height = actual_width, actual_height
|
||||
|
||||
@check_if_not_connected
|
||||
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (depth) synchronously from the camera.
|
||||
@@ -319,9 +327,6 @@ class RealSenseCamera(Camera):
|
||||
This is a blocking call. It waits for a coherent set of frames (depth)
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Args:
|
||||
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The depth map as a NumPy array (height, width)
|
||||
of type `np.uint16` (raw depth values in millimeters) and rotation.
|
||||
@@ -330,44 +335,50 @@ class RealSenseCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If reading frames from the pipeline fails or frames are invalid.
|
||||
"""
|
||||
if timeout_ms:
|
||||
logger.warning(
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
self.new_frame_event.clear()
|
||||
|
||||
_ = self.async_read(timeout_ms=10000)
|
||||
|
||||
with self.frame_lock:
|
||||
depth_map = self.latest_depth_frame
|
||||
|
||||
if depth_map is None:
|
||||
raise RuntimeError("No depth frame available. Ensure camera is streaming.")
|
||||
|
||||
return depth_map
|
||||
|
||||
def _read_from_hardware(self):
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000)
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read_depth failed (status={ret}).")
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
|
||||
depth_frame = frame.get_depth_frame()
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
return frame
|
||||
|
||||
depth_map_processed = self._postprocess_image(depth_map, depth_frame=True)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return depth_map_processed
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
|
||||
This is a blocking call. It waits for a coherent set of frames (color)
|
||||
from the camera hardware via the RealSense pipeline.
|
||||
|
||||
Args:
|
||||
timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured color frame as a NumPy array
|
||||
(height, width, channels), processed according to `color_mode` and rotation.
|
||||
@@ -378,39 +389,36 @@ class RealSenseCamera(Camera):
|
||||
ValueError: If an invalid `color_mode` is requested.
|
||||
"""
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.rs_pipeline is None:
|
||||
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
|
||||
if timeout_ms:
|
||||
logger.warning(
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"{self} read failed (status={ret}).")
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
color_frame = frame.get_color_frame()
|
||||
color_image_raw = np.asanyarray(color_frame.get_data())
|
||||
self.new_frame_event.clear()
|
||||
|
||||
color_image_processed = self._postprocess_image(color_image_raw, color_mode)
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return color_image_processed
|
||||
return frame
|
||||
|
||||
def _postprocess_image(
|
||||
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
|
||||
) -> NDArray[Any]:
|
||||
def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]:
|
||||
"""
|
||||
Applies color conversion, dimension validation, and rotation to a raw color frame.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The raw image frame (expected RGB format from RealSense).
|
||||
color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
|
||||
uses the instance's default `self.color_mode`.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
|
||||
@@ -421,9 +429,9 @@ class RealSenseCamera(Camera):
|
||||
`width` and `height`.
|
||||
"""
|
||||
|
||||
if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
|
||||
)
|
||||
|
||||
if depth_frame:
|
||||
@@ -454,7 +462,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame with 500ms timeout
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
2. Stores result in latest_frame and updates timestamp (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
@@ -462,25 +470,41 @@ class RealSenseCamera(Camera):
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read(timeout_ms=500)
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
color_frame = np.asanyarray(color_frame_raw.get_data())
|
||||
processed_color_frame = self._postprocess_image(color_frame)
|
||||
|
||||
if self.use_depth:
|
||||
depth_frame_raw = frame.get_depth_frame()
|
||||
depth_frame = np.asanyarray(depth_frame_raw.get_data())
|
||||
processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True)
|
||||
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.latest_color_frame = processed_color_frame
|
||||
if self.use_depth:
|
||||
self.latest_depth_frame = processed_depth_frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
self._stop_read_thread()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
@@ -498,7 +522,14 @@ class RealSenseCamera(Camera):
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame data (color) asynchronously.
|
||||
@@ -506,6 +537,7 @@ class RealSenseCamera(Camera):
|
||||
This method retrieves the most recent color frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
It is “best effort” under high FPS.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
@@ -520,21 +552,18 @@ class RealSenseCamera(Camera):
|
||||
TimeoutError: If no frame data becomes available within the specified timeout.
|
||||
RuntimeError: If the background thread died unexpectedly or another error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
f"Read thread alive: {self.thread.is_alive()}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
frame = self.latest_color_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if frame is None:
|
||||
@@ -542,6 +571,42 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_color_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
@@ -565,4 +630,10 @@ class RealSenseCamera(Camera):
|
||||
self.rs_pipeline = None
|
||||
self.rs_profile = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -60,20 +60,8 @@ class RealSenseCameraConfig(CameraConfig):
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
if self.rotation not in (
|
||||
Cv2Rotation.NO_ROTATION,
|
||||
Cv2Rotation.ROTATE_90,
|
||||
Cv2Rotation.ROTATE_180,
|
||||
Cv2Rotation.ROTATE_270,
|
||||
):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
self.rotation = Cv2Rotation(self.rotation)
|
||||
|
||||
values = (self.fps, self.width, self.height)
|
||||
if any(v is not None for v in values) and any(v is None for v in values):
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import platform
|
||||
from typing import cast
|
||||
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
@@ -68,14 +67,3 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
|
||||
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_cv2_backend() -> int:
|
||||
import cv2
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
|
||||
# elif platform.system() == "Darwin": # macOS
|
||||
# return cv2.CAP_AVFOUNDATION
|
||||
else: # Linux and others
|
||||
return int(cv2.CAP_ANY)
|
||||
|
||||
@@ -34,7 +34,8 @@ import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
@@ -45,6 +46,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ZMQCamera(Camera):
|
||||
"""
|
||||
Manages camera interactions via ZeroMQ for receiving frames from a remote server.
|
||||
|
||||
This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes
|
||||
incoming JSON messages containing Base64 encoded images. It supports both
|
||||
synchronous and asynchronous frame reading patterns.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
|
||||
@@ -52,7 +59,16 @@ class ZMQCamera(Camera):
|
||||
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
|
||||
camera = ZMQCamera(config)
|
||||
camera.connect()
|
||||
frame = camera.read()
|
||||
|
||||
# Read 1 frame synchronously (blocking)
|
||||
color_image = camera.read()
|
||||
|
||||
# Read 1 frame asynchronously (waits for new frame with a timeout)
|
||||
async_image = camera.async_read()
|
||||
|
||||
# Get the latest frame immediately (no wait, returns timestamp)
|
||||
latest_image, timestamp = camera.read_latest()
|
||||
|
||||
camera.disconnect()
|
||||
```
|
||||
"""
|
||||
@@ -68,14 +84,17 @@ class ZMQCamera(Camera):
|
||||
self.color_mode = config.color_mode
|
||||
self.timeout_ms = config.timeout_ms
|
||||
|
||||
# ZMQ Context and Socket
|
||||
self.context: zmq.Context | None = None
|
||||
self.socket: zmq.Socket | None = None
|
||||
self._connected = False
|
||||
|
||||
# Threading resources
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: NDArray[Any] | None = None
|
||||
self.latest_timestamp: float | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -83,12 +102,17 @@ class ZMQCamera(Camera):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the ZMQ socket is initialized and connected."""
|
||||
return self._connected and self.context is not None and self.socket is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""Connect to ZMQ camera server."""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
"""Connect to ZMQ camera server.
|
||||
|
||||
Args:
|
||||
warmup (bool): If True, waits for the camera to provide at least one
|
||||
valid frame before returning. Defaults to True.
|
||||
"""
|
||||
|
||||
logger.info(f"Connecting to {self}...")
|
||||
|
||||
@@ -103,17 +127,28 @@ class ZMQCamera(Camera):
|
||||
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
|
||||
self._connected = True
|
||||
|
||||
# Auto-detect resolution
|
||||
# Auto-detect resolution if not provided
|
||||
if self.width is None or self.height is None:
|
||||
h, w = self.read().shape[:2]
|
||||
# Read directly from hardware because the thread isn't running yet
|
||||
temp_frame = self._read_from_hardware()
|
||||
h, w = temp_frame.shape[:2]
|
||||
self.height = h
|
||||
self.width = w
|
||||
logger.info(f"{self} resolution: {w}x{h}")
|
||||
logger.info(f"{self} resolution detected: {w}x{h}")
|
||||
|
||||
self._start_read_thread()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
if warmup:
|
||||
time.sleep(0.1)
|
||||
# Ensure we have captured at least one frame via the thread
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout
|
||||
self.async_read(timeout_ms=self.config.warmup_s * 1000)
|
||||
time.sleep(0.1)
|
||||
|
||||
with self.frame_lock:
|
||||
if self.latest_frame is None:
|
||||
raise ConnectionError(f"{self} failed to capture frames during warmup.")
|
||||
|
||||
except Exception as e:
|
||||
self._cleanup()
|
||||
@@ -131,15 +166,14 @@ class ZMQCamera(Camera):
|
||||
|
||||
@staticmethod
|
||||
def find_cameras() -> list[dict[str, Any]]:
|
||||
"""ZMQ cameras require manual configuration (server address/port)."""
|
||||
return []
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Read a single frame from the ZMQ camera.
|
||||
Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port).
|
||||
"""
|
||||
raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.")
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decoded frame (height, width, 3)
|
||||
def _read_from_hardware(self) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame directly from the ZMQ socket.
|
||||
"""
|
||||
if not self.is_connected or self.socket is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@@ -147,6 +181,7 @@ class ZMQCamera(Camera):
|
||||
try:
|
||||
message = self.socket.recv_string()
|
||||
except Exception as e:
|
||||
# Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import
|
||||
if type(e).__name__ == "Again":
|
||||
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
|
||||
raise
|
||||
@@ -176,42 +211,114 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This is a blocking call. It waits for the next available frame from the
|
||||
camera background thread.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Decoded frame (height, width, 3)
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if color_mode is not None:
|
||||
logger.warning(
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
self.new_frame_event.clear()
|
||||
frame = self.async_read(timeout_ms=10000)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
while self.stop_event and not self.stop_event.is_set():
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self.read()
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = frame
|
||||
self.latest_timestamp = capture_time
|
||||
self.new_frame_event.set()
|
||||
failure_count = 0
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except TimeoutError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Read error: {e}")
|
||||
except (TimeoutError, Exception) as e:
|
||||
if failure_count <= 10:
|
||||
failure_count += 1
|
||||
logger.warning(f"Read error: {e}")
|
||||
else:
|
||||
raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
if self.thread and self.thread.is_alive():
|
||||
return
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, daemon=True)
|
||||
self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop")
|
||||
self.thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
if self.stop_event:
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
if self.thread and self.thread.is_alive():
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
|
||||
"""Read latest frame asynchronously (non-blocking)."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if not self.thread or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
TimeoutError: If no frame data becomes available within the specified timeout.
|
||||
RuntimeError: If the background thread is not running.
|
||||
"""
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
|
||||
@@ -225,11 +332,54 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
This method is non-blocking and returns whatever is currently in the
|
||||
memory buffer. The frame may be stale,
|
||||
meaning it could have been captured a while ago (hanging camera scenario e.g.).
|
||||
|
||||
Returns:
|
||||
NDArray[Any]: The frame image (numpy array).
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the latest frame is older than `max_age_ms`.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from ZMQ camera."""
|
||||
if not self.is_connected and not self.thread:
|
||||
if not self.is_connected and self.thread is None:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
self._stop_read_thread()
|
||||
if self.thread is not None:
|
||||
self._stop_read_thread()
|
||||
|
||||
self._cleanup()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = None
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -29,12 +29,10 @@ class ZMQCameraConfig(CameraConfig):
|
||||
camera_name: str = "zmq_camera"
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
timeout_ms: int = 5000
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
|
||||
if self.timeout_ms <= 0:
|
||||
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")
|
||||
|
||||
@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
|
||||
normalization mode to apply.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||
the original scale.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
|
||||
@@ -116,6 +116,9 @@ def update_meta_data(
|
||||
Adjusts all indices and timestamps to account for previously aggregated
|
||||
data and videos in the destination dataset.
|
||||
|
||||
For data file indices, uses the 'src_to_dst' mapping from aggregate_data()
|
||||
to correctly map source file indices to their destination locations.
|
||||
|
||||
Args:
|
||||
df: DataFrame containing the metadata to be updated.
|
||||
dst_meta: Destination dataset metadata.
|
||||
@@ -129,8 +132,50 @@ def update_meta_data(
|
||||
|
||||
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
|
||||
# Update data file indices using source-to-destination mapping
|
||||
# This is critical for handling datasets that are already results of a merge
|
||||
data_src_to_dst = data_idx.get("src_to_dst", {})
|
||||
if data_src_to_dst:
|
||||
# Store original indices for lookup
|
||||
df["_orig_data_chunk"] = df["data/chunk_index"].copy()
|
||||
df["_orig_data_file"] = df["data/file_index"].copy()
|
||||
|
||||
# Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file)
|
||||
# This is much faster than per-row iteration for large metadata tables
|
||||
mapping_index = pd.MultiIndex.from_tuples(
|
||||
list(data_src_to_dst.keys()),
|
||||
names=["chunk_index", "file_index"],
|
||||
)
|
||||
mapping_values = list(data_src_to_dst.values())
|
||||
mapping_df = pd.DataFrame(
|
||||
mapping_values,
|
||||
index=mapping_index,
|
||||
columns=["dst_chunk", "dst_file"],
|
||||
)
|
||||
|
||||
# Construct a MultiIndex for each row based on original data indices
|
||||
row_index = pd.MultiIndex.from_arrays(
|
||||
[df["_orig_data_chunk"], df["_orig_data_file"]],
|
||||
names=["chunk_index", "file_index"],
|
||||
)
|
||||
|
||||
# Align mapping to rows; missing keys fall back to the default destination
|
||||
reindexed = mapping_df.reindex(row_index)
|
||||
reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna(
|
||||
{"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]}
|
||||
)
|
||||
|
||||
# Assign mapped destination indices back to the DataFrame
|
||||
df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy()
|
||||
df["data/file_index"] = reindexed["dst_file"].to_numpy()
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"])
|
||||
else:
|
||||
# Fallback to simple offset (backward compatibility for single-file sources)
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
# Store original video file indices before updating
|
||||
orig_chunk_col = f"videos/{key}/chunk_index"
|
||||
@@ -146,8 +191,7 @@ def update_meta_data(
|
||||
if src_to_dst:
|
||||
# Map each episode to its correct destination file and apply offset
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
|
||||
# Get destination chunk/file for this source file
|
||||
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
||||
@@ -163,8 +207,7 @@ def update_meta_data(
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
@@ -262,6 +305,10 @@ def aggregate_datasets(
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
# Clear the src_to_dst mapping after processing each source dataset
|
||||
# to avoid interference between different source datasets
|
||||
data_idx.pop("src_to_dst", None)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
|
||||
@@ -312,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
dst_file_durations = video_idx["dst_file_durations"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
# Convert to Python int to ensure consistent dict keys
|
||||
src_chunk_idx = int(src_chunk_idx)
|
||||
src_file_idx = int(src_file_idx)
|
||||
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
@@ -388,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
Reads source data files, updates indices to match the aggregated dataset,
|
||||
and writes them to the destination with proper file rotation.
|
||||
|
||||
Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file)
|
||||
which is critical for correctly updating episode metadata when source datasets
|
||||
have multiple data files (e.g., from a previous merge operation).
|
||||
|
||||
Args:
|
||||
src_meta: Source dataset metadata.
|
||||
dst_meta: Destination dataset metadata.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
data_files_size_in_mb: Maximum size for data files in MB.
|
||||
chunk_size: Maximum number of files per chunk.
|
||||
|
||||
Returns:
|
||||
dict: Updated data_idx with current chunk and file indices.
|
||||
@@ -409,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
# retrieve features schema for proper image typing in parquet
|
||||
hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None
|
||||
|
||||
# Track source to destination file mapping for metadata update
|
||||
# This is critical for handling datasets that are already results of a merge
|
||||
src_to_dst: dict[tuple[int, int], tuple[int, int]] = {}
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||
@@ -421,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
data_idx = append_or_create_parquet_file(
|
||||
# Write data and get the actual destination file it was written to
|
||||
# This avoids duplicating the rotation logic here
|
||||
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
data_idx,
|
||||
@@ -433,6 +488,12 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
hf_features=hf_features,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file)
|
||||
|
||||
# Add the mapping to data_idx for use in metadata update
|
||||
data_idx["src_to_dst"] = src_to_dst
|
||||
|
||||
return data_idx
|
||||
|
||||
|
||||
@@ -473,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
videos_idx,
|
||||
)
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
meta_idx, _ = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
meta_idx,
|
||||
@@ -501,7 +562,7 @@ def append_or_create_parquet_file(
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
):
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
Manages file rotation when size limits are exceeded to prevent individual files
|
||||
@@ -519,9 +580,11 @@ def append_or_create_parquet_file(
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
|
||||
Returns:
|
||||
dict: Updated index dictionary with current chunk and file indices.
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
and (dst_chunk, dst_file) is the actual destination file the data was written to.
|
||||
"""
|
||||
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
|
||||
if not dst_path.exists():
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -529,14 +592,15 @@ def append_or_create_parquet_file(
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx
|
||||
return idx, (dst_chunk, dst_file)
|
||||
|
||||
src_size = get_parquet_file_size_in_mb(src_path)
|
||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= max_mb:
|
||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
dst_chunk, dst_file = idx["chunk"], idx["file"]
|
||||
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
final_df = df
|
||||
target_path = new_path
|
||||
@@ -555,7 +619,7 @@ def append_or_create_parquet_file(
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
return idx
|
||||
return idx, (dst_chunk, dst_file)
|
||||
|
||||
|
||||
def finalize_aggregation(aggr_meta, all_metadata):
|
||||
|
||||
@@ -1396,6 +1396,132 @@ BYTES_PER_KIB = 1024
|
||||
BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB
|
||||
|
||||
|
||||
def modify_tasks(
|
||||
dataset: LeRobotDataset,
|
||||
new_task: str | None = None,
|
||||
episode_tasks: dict[int, str] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Modify tasks in a LeRobotDataset.
|
||||
|
||||
This function allows you to either:
|
||||
1. Set a single task for the entire dataset (using `new_task`)
|
||||
2. Set specific tasks for specific episodes (using `episode_tasks`)
|
||||
|
||||
You can combine both: `new_task` sets the default, and `episode_tasks` overrides
|
||||
specific episodes.
|
||||
|
||||
The dataset is modified in-place, updating only the task-related files:
|
||||
- meta/tasks.parquet
|
||||
- data/**/*.parquet (task_index column)
|
||||
- meta/episodes/**/*.parquet (tasks column)
|
||||
- meta/info.json (total_tasks)
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset to modify.
|
||||
new_task: A single task string to apply to all episodes. If None and episode_tasks
|
||||
is also None, raises an error.
|
||||
episode_tasks: Optional dict mapping episode indices to their task strings.
|
||||
Overrides `new_task` for specific episodes.
|
||||
|
||||
|
||||
Examples:
|
||||
Set a single task for all episodes:
|
||||
dataset = modify_tasks(dataset, new_task="Pick up the cube")
|
||||
|
||||
Set different tasks for specific episodes:
|
||||
dataset = modify_tasks(
|
||||
dataset,
|
||||
episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"}
|
||||
)
|
||||
|
||||
Set a default task with overrides:
|
||||
dataset = modify_tasks(
|
||||
dataset,
|
||||
new_task="Default task",
|
||||
episode_tasks={5: "Special task for episode 5"}
|
||||
)
|
||||
"""
|
||||
if new_task is None and episode_tasks is None:
|
||||
raise ValueError("Must specify at least one of new_task or episode_tasks")
|
||||
|
||||
if episode_tasks is not None:
|
||||
valid_indices = set(range(dataset.meta.total_episodes))
|
||||
invalid = set(episode_tasks.keys()) - valid_indices
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid episode indices: {invalid}")
|
||||
|
||||
# Ensure episodes metadata is loaded
|
||||
if dataset.meta.episodes is None:
|
||||
dataset.meta.episodes = load_episodes(dataset.root)
|
||||
|
||||
# Build the mapping from episode index to task string
|
||||
episode_to_task: dict[int, str] = {}
|
||||
for ep_idx in range(dataset.meta.total_episodes):
|
||||
if episode_tasks and ep_idx in episode_tasks:
|
||||
episode_to_task[ep_idx] = episode_tasks[ep_idx]
|
||||
elif new_task is not None:
|
||||
episode_to_task[ep_idx] = new_task
|
||||
else:
|
||||
# Keep original task if not overridden and no default provided
|
||||
original_tasks = dataset.meta.episodes[ep_idx]["tasks"]
|
||||
if not original_tasks:
|
||||
raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided")
|
||||
episode_to_task[ep_idx] = original_tasks[0]
|
||||
|
||||
# Collect all unique tasks and create new task mapping
|
||||
unique_tasks = sorted(set(episode_to_task.values()))
|
||||
new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks)
|
||||
task_to_index = {task: idx for idx, task in enumerate(unique_tasks)}
|
||||
|
||||
logging.info(f"Modifying tasks in {dataset.repo_id}")
|
||||
logging.info(f"New tasks: {unique_tasks}")
|
||||
|
||||
root = dataset.root
|
||||
|
||||
# Update data files - modify task_index column
|
||||
logging.info("Updating data files...")
|
||||
data_dir = root / DATA_DIR
|
||||
|
||||
for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Build a mapping from episode_index to new task_index for rows in this file
|
||||
episode_indices_in_file = df["episode_index"].unique()
|
||||
ep_to_new_task_idx = {
|
||||
ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file
|
||||
}
|
||||
|
||||
# Update task_index column
|
||||
df["task_index"] = df["episode_index"].map(ep_to_new_task_idx)
|
||||
df.to_parquet(parquet_path, index=False)
|
||||
|
||||
# Update episodes metadata - modify tasks column
|
||||
logging.info("Updating episodes metadata...")
|
||||
episodes_dir = root / "meta" / "episodes"
|
||||
|
||||
for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Update tasks column
|
||||
df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]])
|
||||
df.to_parquet(parquet_path, index=False)
|
||||
|
||||
# Write new tasks.parquet
|
||||
write_tasks(new_task_df, root)
|
||||
|
||||
# Update info.json
|
||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
||||
write_info(dataset.meta.info, root)
|
||||
|
||||
# Reload metadata to reflect changes
|
||||
dataset.meta.tasks = new_task_df
|
||||
dataset.meta.episodes = load_episodes(root)
|
||||
|
||||
logging.info(f"Tasks: {unique_tasks}")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
|
||||
@@ -57,6 +57,7 @@ from lerobot.datasets.utils import (
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
@@ -162,6 +163,7 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -518,6 +520,7 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
@@ -1075,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self.features and self.meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -216,16 +216,17 @@ class ImageTransformsConfig:
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
if cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
elif cfg.type == "RandomAffine":
|
||||
return v2.RandomAffine(**cfg.kwargs)
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
transform_cls = getattr(v2, cfg.type, None)
|
||||
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
|
||||
return transform_cls(**cfg.kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Transform '{cfg.type}' is not valid. It must be a class in "
|
||||
f"torchvision.transforms.v2 or 'SharpnessJitter'."
|
||||
)
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
|
||||
@@ -60,6 +60,7 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
|
||||
@@ -205,6 +205,7 @@ class ObservationConfig:
|
||||
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -112,6 +112,7 @@ class LiberoEnv(gym.Env):
|
||||
visualization_height: int = 480,
|
||||
init_states: bool = True,
|
||||
episode_index: int = 0,
|
||||
n_envs: int = 1,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
num_steps_wait: int = 10,
|
||||
control_mode: str = "relative",
|
||||
@@ -145,7 +146,9 @@ class LiberoEnv(gym.Env):
|
||||
self.episode_length = episode_length
|
||||
# Load once and keep
|
||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
|
||||
|
||||
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
|
||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||
default_steps = 500
|
||||
@@ -295,7 +298,8 @@ class LiberoEnv(gym.Env):
|
||||
self._env.seed(seed)
|
||||
raw_obs = self._env.reset()
|
||||
if self.init_states and self._init_states is not None:
|
||||
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
|
||||
self.init_state_id += self._reset_stride # Change init_state_id when reset
|
||||
|
||||
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||
@@ -373,6 +377,7 @@ def _make_env_fns(
|
||||
init_states=init_states,
|
||||
episode_length=episode_length,
|
||||
episode_index=episode_index,
|
||||
n_envs=n_envs,
|
||||
control_mode=control_mode,
|
||||
**local_kwargs,
|
||||
)
|
||||
|
||||
@@ -221,7 +221,7 @@ class RangeFinderGUI:
|
||||
|
||||
self.bus = bus
|
||||
self.groups = groups if groups is not None else {"all": list(bus.motors)}
|
||||
self.group_names = list(groups)
|
||||
self.group_names = list(self.groups)
|
||||
self.current_group = self.group_names[0]
|
||||
|
||||
if not bus.is_connected:
|
||||
@@ -230,18 +230,20 @@ class RangeFinderGUI:
|
||||
self.calibration = bus.read_calibration()
|
||||
self.res_table = bus.model_resolution_table
|
||||
self.present_cache = {
|
||||
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
|
||||
m: bus.read("Present_Position", m, normalize=False)
|
||||
for motors in self.groups.values()
|
||||
for m in motors
|
||||
}
|
||||
|
||||
pygame.init()
|
||||
self.font = pygame.font.Font(None, FONT_SIZE)
|
||||
|
||||
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
|
||||
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
|
||||
self.label_pad = label_pad
|
||||
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
|
||||
self.controls_bottom = 10 + SAVE_H
|
||||
self.base_y = self.controls_bottom + TOP_GAP
|
||||
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
|
||||
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
|
||||
|
||||
self.screen = pygame.display.set_mode((width, height))
|
||||
pygame.display.set_caption("Motors range finder")
|
||||
|
||||
@@ -23,17 +23,20 @@ from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
|
||||
if TYPE_CHECKING or _can_available:
|
||||
import can
|
||||
else:
|
||||
can.Message = object
|
||||
can.interface = None
|
||||
|
||||
class can: # noqa: N801
|
||||
Message = object
|
||||
interface = None
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
@@ -152,6 +155,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
"""Check if the CAN bus is connected."""
|
||||
return self._is_connected and self.canbus is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""
|
||||
Open the CAN bus and initialize communication.
|
||||
@@ -159,10 +163,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Args:
|
||||
handshake: If True, ping all motors to verify they're present
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected."
|
||||
)
|
||||
|
||||
try:
|
||||
# Auto-detect interface type based on port name
|
||||
@@ -206,11 +206,34 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Raises ConnectionError if any motor fails to respond.
|
||||
"""
|
||||
logger.info("Starting handshake with motors...")
|
||||
missing_motors = []
|
||||
|
||||
# Drain any pending messages
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
while self.canbus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
missing_motors = []
|
||||
for motor_name in self.motors:
|
||||
msg = self._refresh_motor(motor_name)
|
||||
if msg is None:
|
||||
motor_id = self._get_motor_id(motor_name)
|
||||
recv_id = self._get_motor_recv_id(motor_name)
|
||||
|
||||
# Send enable command
|
||||
data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
# Wait for response with longer timeout
|
||||
response = None
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 0.1:
|
||||
response = self.canbus.recv(timeout=0.1)
|
||||
if response and response.arbitration_id == recv_id:
|
||||
break
|
||||
response = None
|
||||
|
||||
if response is None:
|
||||
missing_motors.append(motor_name)
|
||||
else:
|
||||
self._process_response(motor_name, msg)
|
||||
@@ -223,6 +246,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
)
|
||||
logger.info("Handshake successful. All motors ready.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""
|
||||
Close the CAN bus connection.
|
||||
@@ -230,8 +254,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Args:
|
||||
disable_torque: If True, disable torque on all motors before disconnecting
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
|
||||
|
||||
if disable_torque:
|
||||
try:
|
||||
@@ -259,7 +281,11 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_name = self._get_motor_name(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [command_byte]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
self.canbus.send(msg)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
@@ -317,7 +343,11 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_id = self._get_motor_id(motor)
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
@@ -333,6 +363,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Returns:
|
||||
CAN message if received, None otherwise
|
||||
"""
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages_seen = []
|
||||
@@ -371,10 +405,13 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Returns:
|
||||
Dictionary mapping recv_id to CAN message
|
||||
"""
|
||||
responses = {}
|
||||
responses: dict[int, can.Message] = {}
|
||||
expected_set = set(expected_recv_ids)
|
||||
start_time = time.time()
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
try:
|
||||
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||
# 100us poll timeout
|
||||
@@ -438,8 +475,11 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
@@ -465,6 +505,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
# Step 1: Send all MIT control commands
|
||||
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
@@ -472,7 +515,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
|
||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||
@@ -539,10 +582,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, data_name: str, motor: str) -> Value:
|
||||
"""Read a value from a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Refresh motor to get latest state
|
||||
msg = self._refresh_motor(motor)
|
||||
@@ -572,6 +614,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
return mapping[data_name]
|
||||
|
||||
@check_if_not_connected
|
||||
def write(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -582,8 +625,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Write a value to a single motor. Positions are always in degrees.
|
||||
Can write 'Goal_Position', 'Kp', or 'Kd'.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
self._gains[motor][data_name.lower()] = float(value)
|
||||
@@ -633,14 +674,18 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
def _batch_refresh(self, motors: list[str]) -> None:
|
||||
"""Internal helper to refresh a list of motors and update cache."""
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
# Send refresh commands
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False)
|
||||
msg = can.Message(
|
||||
arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd
|
||||
)
|
||||
self.canbus.send(msg)
|
||||
# Small delay to reduce bus congestion if necessary, though removed in sync_read previously
|
||||
# precise_sleep(PRECISE_SLEEP_SEC)
|
||||
|
||||
# Collect responses
|
||||
expected_recv_ids = [self._get_motor_recv_id(m) for m in motors]
|
||||
@@ -655,10 +700,12 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
@check_if_not_connected
|
||||
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
|
||||
"""
|
||||
Write values to multiple motors simultaneously. Positions are always in degrees.
|
||||
"""
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
key = data_name.lower()
|
||||
for motor, val in values.items():
|
||||
@@ -667,6 +714,8 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
elif data_name == "Goal_Position":
|
||||
# Step 1: Send all MIT control commands
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
for motor, value_degrees in values.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
@@ -676,7 +725,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
kd = self._gains[motor]["kd"]
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||
msg = can.Message(
|
||||
arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd
|
||||
)
|
||||
self.canbus.send(msg)
|
||||
precise_sleep(PRECISE_TIMEOUT_SEC)
|
||||
|
||||
@@ -707,9 +758,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self,
|
||||
motors: NameOrID | list[NameOrID] | None = None,
|
||||
motors: str | list[str] | None = None,
|
||||
display_values: bool = True,
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
) -> tuple[dict[str, Value], dict[str, Value]]:
|
||||
"""
|
||||
Interactively record the min/max values of each motor in degrees.
|
||||
|
||||
|
||||
@@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
for motor, m in self.motors.items():
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=drive_modes[motor],
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
drive_mode=int(drive_modes[motor]),
|
||||
homing_offset=int(offsets[motor]),
|
||||
range_min=int(mins[motor]),
|
||||
range_max=int(maxes[motor]),
|
||||
)
|
||||
|
||||
return calibration
|
||||
@@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
@@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
|
||||
@@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
On Dynamixel Motors:
|
||||
Present_Position = Actual_Position + Homing_Offset
|
||||
"""
|
||||
half_turn_homings = {}
|
||||
half_turn_homings: dict[NameOrID, Value] = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
@@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
|
||||
return
|
||||
return None
|
||||
|
||||
return {id_: data[0] for id_, data in data_list.items()}
|
||||
|
||||
@@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
|
||||
self.port_handler = scs.PortHandler(self.port)
|
||||
# HACK: monkeypatch
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
|
||||
self.port_handler, scs.PortHandler
|
||||
)
|
||||
self.packet_handler = scs.PacketHandler(protocol_version)
|
||||
@@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=0,
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
homing_offset=int(offsets[motor]),
|
||||
range_min=int(mins[motor]),
|
||||
range_max=int(maxes[motor]),
|
||||
)
|
||||
|
||||
return calibration
|
||||
@@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
On Feetech Motors:
|
||||
Present_Position = Actual_Position - Homing_Offset
|
||||
"""
|
||||
half_turn_homings = {}
|
||||
half_turn_homings: dict[NameOrID, Value] = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
@@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
|
||||
return half_turn_homings
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
@@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
self._write(addr, length, motor, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 1, num_retry=num_retry)
|
||||
@@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
|
||||
import scservo_sdk as scs
|
||||
|
||||
data_list = {}
|
||||
data_list: dict[int, int] = {}
|
||||
|
||||
status_length = 6
|
||||
|
||||
@@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
return
|
||||
return None
|
||||
|
||||
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
||||
if ids_errors:
|
||||
|
||||
@@ -23,6 +23,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
|
||||
"""Write values to multiple motors."""
|
||||
pass
|
||||
|
||||
@@ -179,15 +180,16 @@ class Motor:
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool
|
||||
self.baudrate: int
|
||||
self.packet_start_time: float
|
||||
self.packet_timeout: float
|
||||
self.tx_time_per_byte: float
|
||||
self.is_using: bool
|
||||
self.port_name: str
|
||||
self.ser: serial.Serial
|
||||
is_open: bool
|
||||
baudrate: int
|
||||
packet_start_time: float
|
||||
packet_timeout: float
|
||||
tx_time_per_byte: float
|
||||
is_using: bool
|
||||
port_name: str
|
||||
ser: serial.Serial
|
||||
|
||||
def __init__(self, port_name: str) -> None: ...
|
||||
|
||||
def openPort(self): ...
|
||||
def closePort(self): ...
|
||||
@@ -240,19 +242,22 @@ class PacketHandler(Protocol):
|
||||
def regWriteTxRx(self, port, id, address, length, data): ...
|
||||
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
|
||||
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
|
||||
def broadcastPing(self, port): ...
|
||||
|
||||
|
||||
class GroupSyncRead(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.last_result: bool
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
port: str
|
||||
ph: PortHandler
|
||||
start_address: int
|
||||
data_length: int
|
||||
last_result: bool
|
||||
is_param_changed: bool
|
||||
param: list
|
||||
data_dict: dict
|
||||
|
||||
def __init__(
|
||||
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
|
||||
) -> None: ...
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id): ...
|
||||
def removeParam(self, id): ...
|
||||
@@ -265,15 +270,17 @@ class GroupSyncRead(Protocol):
|
||||
|
||||
|
||||
class GroupSyncWrite(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
port: str
|
||||
ph: PortHandler
|
||||
start_address: int
|
||||
data_length: int
|
||||
is_param_changed: bool
|
||||
param: list
|
||||
data_dict: dict
|
||||
|
||||
def __init__(
|
||||
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
|
||||
) -> None: ...
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id, data): ...
|
||||
def removeParam(self, id): ...
|
||||
@@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motor_model(self, motor: NameOrID) -> int:
|
||||
def _get_motor_model(self, motor: NameOrID) -> str:
|
||||
if isinstance(motor, str):
|
||||
return self.motors[motor].model
|
||||
elif isinstance(motor, int):
|
||||
@@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
|
||||
if motors is None:
|
||||
return list(self.motors)
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors.copy()
|
||||
elif isinstance(motors, int):
|
||||
return [self._id_to_name(motors)]
|
||||
elif isinstance(motors, Sequence):
|
||||
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
|
||||
else:
|
||||
raise TypeError(motors)
|
||||
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
|
||||
if isinstance(values, (int | float)):
|
||||
return dict.fromkeys(self.ids, values)
|
||||
elif isinstance(values, dict):
|
||||
@@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors.
|
||||
|
||||
Args:
|
||||
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
|
||||
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
|
||||
Defaults to `None`.
|
||||
num_retry (int, optional): Number of additional retry attempts on communication failure.
|
||||
Defaults to 0.
|
||||
"""
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self, motors: int | str | list[str] | None = None):
|
||||
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||
"""Context-manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
@@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
|
||||
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
|
||||
"""Restore factory calibration for the selected motors.
|
||||
|
||||
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
|
||||
The in-memory :pyattr:`calibration` is cleared.
|
||||
|
||||
Args:
|
||||
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
|
||||
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
|
||||
resets every motor.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
motor_names = self._get_motors_list(motors)
|
||||
|
||||
for motor in motors:
|
||||
for motor in motor_names:
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
self.write("Homing_Offset", motor, 0, normalize=False)
|
||||
@@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
self.calibration = {}
|
||||
|
||||
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
|
||||
def set_half_turn_homings(
|
||||
self, motors: NameOrID | Sequence[NameOrID] | None = None
|
||||
) -> dict[NameOrID, Value]:
|
||||
"""Centre each motor range around its current position.
|
||||
|
||||
The function computes and writes a homing offset such that the present position becomes exactly one
|
||||
@@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
|
||||
|
||||
Returns:
|
||||
dict[NameOrID, Value]: Mapping *motor → written homing offset*.
|
||||
dict[str, Value]: Mapping *motor name → written homing offset*.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
motor_names = self._get_motors_list(motors)
|
||||
|
||||
self.reset_calibration(motors)
|
||||
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
self.reset_calibration(motor_names)
|
||||
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
homing_offsets = self._get_half_turn_homings(actual_positions)
|
||||
for motor, offset in homing_offsets.items():
|
||||
self.write("Homing_Offset", motor, offset)
|
||||
@@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
pass
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[str, Value], dict[str, Value]]:
|
||||
"""Interactively record the min/max encoder values of each motor.
|
||||
|
||||
Move the joints by hand (with torque disabled) while the method streams live positions. Press
|
||||
@@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
display_values (bool, optional): When `True` (default) a live table is printed to the console.
|
||||
|
||||
Returns:
|
||||
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
|
||||
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
|
||||
extreme values observed for each motor.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
motor_names = self._get_motors_list(motors)
|
||||
|
||||
start_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
user_pressed_enter = False
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
|
||||
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
|
||||
|
||||
if display_values:
|
||||
print("\n-------------------------------------------")
|
||||
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
|
||||
for motor in motors:
|
||||
for motor in motor_names:
|
||||
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
|
||||
|
||||
if enter_pressed():
|
||||
@@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(motors) + 3)
|
||||
move_cursor_up(len(motor_names) + 3)
|
||||
|
||||
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
|
||||
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
|
||||
if same_min_max:
|
||||
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
|
||||
|
||||
@@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
else:
|
||||
return
|
||||
return None
|
||||
if self._is_error(error):
|
||||
if raise_on_error:
|
||||
raise RuntimeError(self.packet_handler.getRxPacketError(error))
|
||||
else:
|
||||
return
|
||||
return None
|
||||
|
||||
return model_number
|
||||
|
||||
@@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
||||
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
id_value = self._decode_sign(data_name, {id_: value})
|
||||
decoded = self._decode_sign(data_name, {id_: value})
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
id_value = self._normalize(id_value)
|
||||
normalized = self._normalize(decoded)
|
||||
return normalized[id_]
|
||||
|
||||
return id_value[id_]
|
||||
return decoded[id_]
|
||||
|
||||
def _read(
|
||||
self,
|
||||
@@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[int, int]:
|
||||
) -> tuple[int, int, int]:
|
||||
if length == 1:
|
||||
read_fn = self.packet_handler.read1ByteTxRx
|
||||
elif length == 2:
|
||||
@@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
model = self.motors[motor].model
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
int_value = int(value)
|
||||
if normalize and data_name in self.normalized_data:
|
||||
value = self._unnormalize({id_: value})[id_]
|
||||
int_value = self._unnormalize({id_: value})[id_]
|
||||
|
||||
value = self._encode_sign(data_name, {id_: value})[id_]
|
||||
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
|
||||
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
def _write(
|
||||
self,
|
||||
@@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: str | list[str] | None = None,
|
||||
motors: NameOrID | Sequence[NameOrID] | None = None,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
@@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
Args:
|
||||
data_name (str): Register name.
|
||||
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
|
||||
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
|
||||
normalize (bool, optional): Normalisation flag. Defaults to `True`.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
|
||||
@@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
||||
ids_values, _ = self._sync_read(
|
||||
raw_ids_values, _ = self._sync_read(
|
||||
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
|
||||
ids_values = self._decode_sign(data_name, ids_values)
|
||||
decoded = self._decode_sign(data_name, raw_ids_values)
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
ids_values = self._normalize(ids_values)
|
||||
normalized = self._normalize(decoded)
|
||||
return {self._id_to_name(id_): value for id_, value in normalized.items()}
|
||||
|
||||
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
|
||||
return {self._id_to_name(id_): value for id_, value in decoded.items()}
|
||||
|
||||
def _sync_read(
|
||||
self,
|
||||
@@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
|
||||
ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||
raw_ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in raw_ids_values]
|
||||
if self._has_different_ctrl_tables:
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = next(iter(models))
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
|
||||
if normalize and data_name in self.normalized_data:
|
||||
ids_values = self._unnormalize(ids_values)
|
||||
int_ids_values = self._unnormalize(raw_ids_values)
|
||||
|
||||
ids_values = self._encode_sign(data_name, ids_values)
|
||||
int_ids_values = self._encode_sign(data_name, int_ids_values)
|
||||
|
||||
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
|
||||
self._sync_write(
|
||||
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
|
||||
def _sync_write(
|
||||
self,
|
||||
|
||||
@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and 'output_shapes`.
|
||||
Those are: `input_features` and `output_features`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- Either:
|
||||
@@ -48,21 +48,12 @@ class ACTConfig(PreTrainedConfig):
|
||||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
|
||||
@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
Those are: `input_features` and `output_features`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -48,21 +48,12 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
@@ -73,7 +64,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
|
||||
@@ -239,8 +239,10 @@ class SACPolicy(
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def update_temperature(self):
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
@@ -457,11 +459,10 @@ class SACPolicy(
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self):
|
||||
"""Set up temperature parameter and initial log_alpha."""
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
|
||||
@@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
loss_dict["losses_after_forward"] = losses.clone().mean().item()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
|
||||
@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
@@ -40,24 +40,12 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
|
||||
match the original implementation.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
|
||||
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
|
||||
normalization mode here.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
|
||||
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
|
||||
latent_dim: Observation's latent embedding dimension.
|
||||
|
||||
@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
Those are: `input_features` and `output_features`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -46,21 +46,12 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
current step and additional steps going back).
|
||||
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
|
||||
action_chunk_size: Action chunk size of each action prediction token.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
|
||||
@@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# copy over non-STATE features
|
||||
for ft, feats in features.items():
|
||||
if ft != PipelineFeatureType.STATE:
|
||||
if ft != FeatureType.STATE:
|
||||
new_features[ft] = feats.copy()
|
||||
|
||||
# rebuild STATE features
|
||||
@@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# add our new flattened state
|
||||
state_feats[OBS_STATE] = PolicyFeature(
|
||||
key=OBS_STATE,
|
||||
type=FeatureType.STATE,
|
||||
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
|
||||
dtype="float32",
|
||||
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
|
||||
)
|
||||
|
||||
new_features[PipelineFeatureType.STATE] = state_feats
|
||||
new_features[FeatureType.STATE] = state_feats
|
||||
|
||||
return new_features
|
||||
|
||||
|
||||
@@ -18,16 +18,18 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
@@ -69,10 +71,10 @@ class HasTeleopEvents(Protocol):
|
||||
|
||||
|
||||
# Type variable constrained to Teleoperator subclasses that also implement events
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator")
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
def _check_teleop_with_events(teleop: "Teleoperator") -> None:
|
||||
"""
|
||||
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
|
||||
|
||||
@@ -103,7 +105,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
teleop_device: The teleoperator instance to get the action from.
|
||||
"""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
teleop_device: "Teleoperator"
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
@@ -312,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
@@ -327,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
The modified transition with the penalty added to complementary data.
|
||||
"""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||
if raw_joint_positions is None:
|
||||
return complementary_data
|
||||
return new_transition
|
||||
|
||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return complementary_data
|
||||
return new_transition
|
||||
|
||||
# Gripper action is a PolicyAction at this stage
|
||||
gripper_action = action[-1].item()
|
||||
@@ -362,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Create new complementary data with penalty info
|
||||
# Update complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
|
||||
return new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -34,6 +34,8 @@ from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
@@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the subtask from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of subtask strings, or None if the subtask key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
subtask = complementary_data.get("subtask")
|
||||
if subtask is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(subtask, str):
|
||||
return [subtask]
|
||||
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
|
||||
return subtask
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
@@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
if subtask is not None:
|
||||
tokenized_subtask = self._tokenize_text(subtask)
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_subtask = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_subtask.items()
|
||||
}
|
||||
|
||||
# Add tokenized subtask to the observation
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
|
||||
@@ -412,7 +412,10 @@ def make_processors(
|
||||
if cfg.processor.observation.add_current_to_observation:
|
||||
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
||||
|
||||
if kinematics_solver is not None:
|
||||
add_ee_pose = (
|
||||
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
|
||||
)
|
||||
if kinematics_solver is not None and add_ee_pose:
|
||||
env_pipeline_steps.append(
|
||||
ForwardKinematicsJointsToEEObservation(
|
||||
kinematics=kinematics_solver,
|
||||
@@ -435,7 +438,12 @@ def make_processors(
|
||||
)
|
||||
|
||||
# Add gripper penalty processor if gripper config exists and enabled
|
||||
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
|
||||
# Only add if max_gripper_pos is explicitly configured (required for normalization)
|
||||
if (
|
||||
cfg.processor.gripper is not None
|
||||
and cfg.processor.gripper.use_gripper
|
||||
and cfg.processor.max_gripper_pos is not None
|
||||
):
|
||||
env_pipeline_steps.append(
|
||||
GripperPenaltyProcessorStep(
|
||||
penalty=cfg.processor.gripper.gripper_penalty,
|
||||
|
||||
@@ -545,9 +545,6 @@ def add_actor_information_and_train(
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
# Update temperature
|
||||
policy.update_temperature()
|
||||
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
|
||||
@@ -26,8 +26,21 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
|
||||
|
||||
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
def cfg_to_group(
|
||||
cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64
|
||||
) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
|
||||
def _maybe_truncate(tag: str) -> str:
|
||||
"""Truncate tag to max_tag_length characters if required.
|
||||
|
||||
wandb rejects tags longer than 64 characters.
|
||||
See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py
|
||||
"""
|
||||
if len(tag) <= max_tag_length:
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"seed:{cfg.seed}",
|
||||
@@ -36,6 +49,8 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
if truncate_tags:
|
||||
lst = [_maybe_truncate(tag) for tag in lst]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
@@ -83,7 +98,7 @@ class WandBLogger:
|
||||
entity=self.cfg.entity,
|
||||
name=self.job_name,
|
||||
notes=self.cfg.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
|
||||
dir=self.log_dir,
|
||||
config=cfg.to_dict(),
|
||||
# TODO(rcadene): try set to True
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_openarm_follower import BiOpenArmFollower
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
|
||||
__all__ = ["BiOpenArmFollower", "BiOpenArmFollowerConfig"]
|
||||
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmFollower(Robot):
|
||||
"""
|
||||
Bimanual OpenArm Follower Arms
|
||||
"""
|
||||
|
||||
config_class = BiOpenArmFollowerConfig
|
||||
name = "bi_openarm_follower"
|
||||
|
||||
def __init__(self, config: BiOpenArmFollowerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
left_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
side=config.left_arm_config.side,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
can_bitrate=config.left_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
joint_limits=config.left_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
right_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=config.right_arm_config.cameras,
|
||||
side=config.right_arm_config.side,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
can_bitrate=config.right_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
joint_limits=config.right_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
self.left_arm = OpenArmFollower(left_arm_config)
|
||||
self.right_arm = OpenArmFollower(right_arm_config)
|
||||
|
||||
# Only for compatibility with other parts of the codebase that expect a `robot.cameras` attribute
|
||||
self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras}
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
left_arm_motors_ft = self.left_arm._motors_ft
|
||||
right_arm_motors_ft = self.right_arm._motors_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
|
||||
}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
left_arm_cameras_ft = self.left_arm._cameras_ft
|
||||
right_arm_cameras_ft = self.right_arm._cameras_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_obs = self.left_arm.get_observation()
|
||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_obs = self.right_arm.get_observation()
|
||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
custom_kd: dict[str, float] | None = None,
|
||||
) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_")
|
||||
}
|
||||
# Remove "right_" prefix
|
||||
right_action = {
|
||||
key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_")
|
||||
}
|
||||
|
||||
sent_action_left = self.left_arm.send_action(left_action, custom_kp, custom_kd)
|
||||
sent_action_right = self.right_arm.send_action(right_action, custom_kp, custom_kd)
|
||||
|
||||
# Add prefixes back
|
||||
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("bi_openarm_follower")
|
||||
@dataclass
|
||||
class BiOpenArmFollowerConfig(RobotConfig):
|
||||
"""Configuration class for Bi OpenArm Follower robots."""
|
||||
|
||||
left_arm_config: OpenArmFollowerConfigBase
|
||||
right_arm_config: OpenArmFollowerConfigBase
|
||||
@@ -19,6 +19,7 @@ from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_bi_so_follower import BiSOFollowerConfig
|
||||
@@ -96,6 +97,7 @@ class BiSOFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -116,6 +118,7 @@ class BiSOFollower(Robot):
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
@@ -129,6 +132,7 @@ class BiSOFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
@@ -148,6 +152,7 @@ class BiSOFollower(Robot):
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_follower import OpenArmFollowerConfig, OpenArmFollowerConfigBase
|
||||
from .openarm_follower import OpenArmFollower
|
||||
|
||||
__all__ = ["OpenArmFollower", "OpenArmFollowerConfig", "OpenArmFollowerConfigBase"]
|
||||
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
LEFT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
|
||||
"joint_1": (-75.0, 75.0),
|
||||
"joint_2": (-90.0, 9.0),
|
||||
"joint_3": (-85.0, 85.0),
|
||||
"joint_4": (0.0, 135.0),
|
||||
"joint_5": (-85.0, 85.0),
|
||||
"joint_6": (-40.0, 40.0),
|
||||
"joint_7": (-80.0, 80.0),
|
||||
"gripper": (-65.0, 0.0),
|
||||
}
|
||||
|
||||
RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
|
||||
"joint_1": (-75.0, 75.0),
|
||||
"joint_2": (-9.0, 90.0),
|
||||
"joint_3": (-85.0, 85.0),
|
||||
"joint_4": (0.0, 135.0),
|
||||
"joint_5": (-85.0, 85.0),
|
||||
"joint_6": (-40.0, 40.0),
|
||||
"joint_7": (-80.0, 80.0),
|
||||
"gripper": (-65.0, 0.0),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenArmFollowerConfigBase:
|
||||
"""Base configuration for the OpenArms follower robot with Damiao motors."""
|
||||
|
||||
# CAN interfaces - one per arm
|
||||
# arm CAN interface (e.g., "can1")
|
||||
# Linux: "can0", "can1", etc.
|
||||
port: str
|
||||
|
||||
# side of the arm: "left" or "right". If "None" default values will be used
|
||||
side: str | None = None
|
||||
|
||||
# CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect)
|
||||
can_interface: str = "socketcan"
|
||||
|
||||
# CAN FD settings (OpenArms uses CAN FD by default)
|
||||
use_can_fd: bool = True
|
||||
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
|
||||
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
|
||||
|
||||
# Whether to disable torque when disconnecting
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# Safety limit for relative target positions
|
||||
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
# Camera configurations
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Motor configuration for OpenArms (7 DOF per arm)
|
||||
# Maps motor names to (send_can_id, recv_can_id, motor_type)
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms uses 4 types of motors:
|
||||
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
|
||||
# - DM4340P and DM4340 for shoulder rotation and elbow
|
||||
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
|
||||
motor_config: dict[str, tuple[int, int, str]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
|
||||
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
|
||||
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
|
||||
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
|
||||
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
|
||||
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
|
||||
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
|
||||
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
|
||||
}
|
||||
)
|
||||
|
||||
# MIT control parameters for position control (used in send_action)
|
||||
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
position_kp: list[float] = field(
|
||||
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0]
|
||||
)
|
||||
position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3, 0.3])
|
||||
|
||||
# Values for joint limits. Can be overridden via CLI (for custom values) or by setting config.side to either 'left' or 'right'.
|
||||
# If config.side is left set to None and no CLI values are passed, the default joint limit values are small for safety.
|
||||
joint_limits: dict[str, tuple[float, float]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (-5.0, 5.0),
|
||||
"joint_2": (-5.0, 5.0),
|
||||
"joint_3": (-5.0, 5.0),
|
||||
"joint_4": (0.0, 5.0),
|
||||
"joint_5": (-5.0, 5.0),
|
||||
"joint_6": (-5.0, 5.0),
|
||||
"joint_7": (-5.0, 5.0),
|
||||
"gripper": (-5.0, 0.0),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("openarm_follower")
|
||||
@dataclass
|
||||
class OpenArmFollowerConfig(RobotConfig, OpenArmFollowerConfigBase):
|
||||
pass
|
||||
@@ -0,0 +1,343 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .config_openarm_follower import (
|
||||
LEFT_DEFAULT_JOINTS_LIMITS,
|
||||
RIGHT_DEFAULT_JOINTS_LIMITS,
|
||||
OpenArmFollowerConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenArmFollower(Robot):
|
||||
"""
|
||||
OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper.
|
||||
The arm uses Damiao motors in MIT control mode.
|
||||
"""
|
||||
|
||||
config_class = OpenArmFollowerConfig
|
||||
name = "openarm_follower"
|
||||
|
||||
def __init__(self, config: OpenArmFollowerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Arm motors
|
||||
motors: dict[str, Motor] = {}
|
||||
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
|
||||
motor = Motor(
|
||||
send_id, motor_type_str, MotorNormMode.DEGREES
|
||||
) # Always use degrees for Damiao motors
|
||||
motor.recv_id = recv_id
|
||||
motor.motor_type_str = motor_type_str
|
||||
motors[motor_name] = motor
|
||||
|
||||
self.bus = DamiaoMotorsBus(
|
||||
port=self.config.port,
|
||||
motors=motors,
|
||||
calibration=self.calibration,
|
||||
can_interface=self.config.can_interface,
|
||||
use_can_fd=self.config.use_can_fd,
|
||||
bitrate=self.config.can_bitrate,
|
||||
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
|
||||
)
|
||||
|
||||
if config.side is not None:
|
||||
if config.side == "left":
|
||||
config.joint_limits = LEFT_DEFAULT_JOINTS_LIMITS
|
||||
elif config.side == "right":
|
||||
config.joint_limits = RIGHT_DEFAULT_JOINTS_LIMITS
|
||||
else:
|
||||
raise ValueError(
|
||||
"config.side must be either 'left', 'right' (for default values) or 'None' (for CLI values)"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Set config.side to either 'left' or 'right' to use pre-configured values for joint limits."
|
||||
)
|
||||
logger.info(f"Values used for joint limits: {config.joint_limits}.")
|
||||
|
||||
# Initialize cameras
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
"""Motor features for observation and action spaces."""
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float # Add this
|
||||
features[f"{motor}.torque"] = float # Add this
|
||||
return features
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
"""Camera features for observation space."""
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
"""Combined observation features from motors and cameras."""
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Action features."""
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if robot is connected."""
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
Connect to the robot and optionally calibrate.
|
||||
|
||||
We assume that at connection time, the arms are in a safe rest position,
|
||||
and torque can be safely disabled to run calibration if needed.
|
||||
"""
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
self.bus.connect()
|
||||
|
||||
# Run calibration if needed
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
|
||||
)
|
||||
self.calibrate()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.configure()
|
||||
|
||||
if self.is_calibrated:
|
||||
self.bus.set_zero_position()
|
||||
|
||||
self.bus.enable_torque()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if robot is calibrated."""
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArms robot.
|
||||
|
||||
The calibration procedure:
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position
|
||||
4. Record range of motion for each joint
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
self.bus.disable_torque()
|
||||
|
||||
# Step 1: Set zero position
|
||||
input(
|
||||
"\nCalibration: Set Zero Position)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
# Set current position as zero for all motors
|
||||
self.bus.set_zero_position()
|
||||
logger.info("Arm zero position set.")
|
||||
|
||||
logger.info("Setting range: -90° to +90° for safety by default for all joints")
|
||||
for motor_name, motor in self.bus.motors.items():
|
||||
self.calibration[motor_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=0,
|
||||
homing_offset=0,
|
||||
range_min=-90,
|
||||
range_max=90,
|
||||
)
|
||||
|
||||
self.bus.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
print(f"Calibration saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Configure motors with appropriate settings."""
|
||||
# TODO(Steven, Pepijn): Slightly different from what it is happening in the leader
|
||||
with self.bus.torque_disabled():
|
||||
self.bus.configure_motors()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Get current observation from robot including position, velocity, and torque.
|
||||
|
||||
Reads all motor states (pos/vel/torque) in one CAN refresh cycle
|
||||
instead of 3 separate reads.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
|
||||
obs_dict: dict[str, Any] = {}
|
||||
|
||||
states = self.bus.sync_read_all_states()
|
||||
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} get_observation took: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
custom_kp: dict[str, float] | None = None,
|
||||
custom_kd: dict[str, float] | None = None,
|
||||
) -> RobotAction:
|
||||
"""
|
||||
Send action command to robot.
|
||||
|
||||
The action magnitude may be clipped based on safety limits.
|
||||
|
||||
Args:
|
||||
action: Dictionary with motor positions (e.g., "joint_1.pos", "joint_2.pos")
|
||||
custom_kp: Optional custom kp gains per motor (e.g., {"joint_1": 120.0, "joint_2": 150.0})
|
||||
custom_kd: Optional custom kd gains per motor (e.g., {"joint_1": 1.5, "joint_2": 2.0})
|
||||
|
||||
Returns:
|
||||
The action actually sent (potentially clipped)
|
||||
"""
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Apply joint limit clipping to arm
|
||||
for motor_name, position in goal_pos.items():
|
||||
if motor_name in self.config.joint_limits:
|
||||
min_limit, max_limit = self.config.joint_limits[motor_name]
|
||||
clipped_position = max(min_limit, min(max_limit, position))
|
||||
if clipped_position != position:
|
||||
logger.debug(f"Clipped {motor_name} from {position:.2f}° to {clipped_position:.2f}°")
|
||||
goal_pos[motor_name] = clipped_position
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# /!\ Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.bus.sync_read("Present_Position")
|
||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
|
||||
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
|
||||
# TODO(Steven, Pepijn): Refactor writing
|
||||
# Motor name to index mapping for gains
|
||||
motor_index = {
|
||||
"joint_1": 0,
|
||||
"joint_2": 1,
|
||||
"joint_3": 2,
|
||||
"joint_4": 3,
|
||||
"joint_5": 4,
|
||||
"joint_6": 5,
|
||||
"joint_7": 6,
|
||||
"gripper": 7,
|
||||
}
|
||||
|
||||
# Use batch MIT control for arm (sends all commands, then collects responses)
|
||||
commands = {}
|
||||
for motor_name, position_degrees in goal_pos.items():
|
||||
idx = motor_index.get(motor_name, 0)
|
||||
# Use custom gains if provided, otherwise use config defaults
|
||||
if custom_kp is not None and motor_name in custom_kp:
|
||||
kp = custom_kp[motor_name]
|
||||
else:
|
||||
kp = (
|
||||
self.config.position_kp[idx]
|
||||
if isinstance(self.config.position_kp, list)
|
||||
else self.config.position_kp
|
||||
)
|
||||
if custom_kd is not None and motor_name in custom_kd:
|
||||
kd = custom_kd[motor_name]
|
||||
else:
|
||||
kd = (
|
||||
self.config.position_kd[idx]
|
||||
if isinstance(self.config.position_kd, list)
|
||||
else self.config.position_kd
|
||||
)
|
||||
commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
|
||||
|
||||
self.bus._mit_control_batch(commands)
|
||||
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
"""Disconnect from robot."""
|
||||
|
||||
# Disconnect CAN bus
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
|
||||
# Disconnect cameras
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -65,3 +65,6 @@ class UnitreeG1Config(RobotConfig):
|
||||
|
||||
# Cameras (ZMQ-based remote cameras)
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Compensates for gravity on the unitree's arms using the arm ik solver
|
||||
gravity_compensation: bool = False
|
||||
|
||||
@@ -18,7 +18,7 @@ from enum import IntEnum
|
||||
|
||||
# ruff: noqa: N801, N815
|
||||
|
||||
NUM_MOTORS = 35
|
||||
NUM_MOTORS = 29
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
parent2_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(parent2_dir)
|
||||
|
||||
|
||||
class WeightedMovingFilter:
|
||||
def __init__(self, weights, data_size=14):
|
||||
self._window_size = len(weights)
|
||||
self._weights = np.array(weights)
|
||||
self._data_size = data_size
|
||||
self._filtered_data = np.zeros(self._data_size)
|
||||
self._data_queue = []
|
||||
|
||||
def _apply_filter(self):
|
||||
if len(self._data_queue) < self._window_size:
|
||||
return self._data_queue[-1]
|
||||
|
||||
data_array = np.array(self._data_queue)
|
||||
temp_filtered_data = np.zeros(self._data_size)
|
||||
for i in range(self._data_size):
|
||||
temp_filtered_data[i] = np.convolve(data_array[:, i], self._weights, mode="valid")[-1]
|
||||
|
||||
return temp_filtered_data
|
||||
|
||||
def add_data(self, new_data):
|
||||
assert len(new_data) == self._data_size
|
||||
|
||||
if len(self._data_queue) > 0 and np.array_equal(
|
||||
new_data, self._data_queue[-1]
|
||||
): # skip duplicate data
|
||||
return
|
||||
|
||||
if len(self._data_queue) >= self._window_size:
|
||||
self._data_queue.pop(0)
|
||||
|
||||
self._data_queue.append(new_data)
|
||||
self._filtered_data = self._apply_filter()
|
||||
|
||||
@property
|
||||
def filtered_data(self):
|
||||
return self._filtered_data
|
||||
|
||||
|
||||
class G1_29_ArmIK: # noqa: N801
|
||||
def __init__(self, unit_test=False):
|
||||
import casadi
|
||||
import pinocchio as pin
|
||||
from huggingface_hub import snapshot_download
|
||||
from pinocchio import casadi as cpin
|
||||
|
||||
self._pin = pin
|
||||
np.set_printoptions(precision=5, suppress=True, linewidth=200)
|
||||
|
||||
self.unit_test = unit_test
|
||||
|
||||
self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco")
|
||||
urdf_path = os.path.join(self.repo_path, "assets", "g1_body29_hand14.urdf")
|
||||
mesh_dir = os.path.join(self.repo_path, "assets")
|
||||
|
||||
self.robot = self._pin.RobotWrapper.BuildFromURDF(urdf_path, mesh_dir)
|
||||
|
||||
self.mixed_jointsToLockIDs = [
|
||||
"left_hip_pitch_joint",
|
||||
"left_hip_roll_joint",
|
||||
"left_hip_yaw_joint",
|
||||
"left_knee_joint",
|
||||
"left_ankle_pitch_joint",
|
||||
"left_ankle_roll_joint",
|
||||
"right_hip_pitch_joint",
|
||||
"right_hip_roll_joint",
|
||||
"right_hip_yaw_joint",
|
||||
"right_knee_joint",
|
||||
"right_ankle_pitch_joint",
|
||||
"right_ankle_roll_joint",
|
||||
"waist_yaw_joint",
|
||||
"waist_roll_joint",
|
||||
"waist_pitch_joint",
|
||||
"left_hand_thumb_0_joint",
|
||||
"left_hand_thumb_1_joint",
|
||||
"left_hand_thumb_2_joint",
|
||||
"left_hand_middle_0_joint",
|
||||
"left_hand_middle_1_joint",
|
||||
"left_hand_index_0_joint",
|
||||
"left_hand_index_1_joint",
|
||||
"right_hand_thumb_0_joint",
|
||||
"right_hand_thumb_1_joint",
|
||||
"right_hand_thumb_2_joint",
|
||||
"right_hand_index_0_joint",
|
||||
"right_hand_index_1_joint",
|
||||
"right_hand_middle_0_joint",
|
||||
"right_hand_middle_1_joint",
|
||||
]
|
||||
|
||||
self.reduced_robot = self.robot.buildReducedRobot(
|
||||
list_of_joints_to_lock=self.mixed_jointsToLockIDs,
|
||||
reference_configuration=np.array([0.0] * self.robot.model.nq),
|
||||
)
|
||||
|
||||
# Arm joint names in G1 motor order (G1_29_JointArmIndex)
|
||||
self._arm_joint_names_g1 = [
|
||||
"left_shoulder_pitch_joint",
|
||||
"left_shoulder_roll_joint",
|
||||
"left_shoulder_yaw_joint",
|
||||
"left_elbow_joint",
|
||||
"left_wrist_roll_joint",
|
||||
"left_wrist_pitch_joint",
|
||||
"left_wrist_yaw_joint",
|
||||
"right_shoulder_pitch_joint",
|
||||
"right_shoulder_roll_joint",
|
||||
"right_shoulder_yaw_joint",
|
||||
"right_elbow_joint",
|
||||
"right_wrist_roll_joint",
|
||||
"right_wrist_pitch_joint",
|
||||
"right_wrist_yaw_joint",
|
||||
]
|
||||
# Pinocchio uses its own joint order in q; build index mapping.
|
||||
self._arm_joint_names_pin = sorted(
|
||||
self._arm_joint_names_g1,
|
||||
key=lambda name: self.reduced_robot.model.idx_qs[self.reduced_robot.model.getJointId(name)],
|
||||
)
|
||||
logger.info(f"Pinocchio arm joint order: {self._arm_joint_names_pin}")
|
||||
self._arm_reorder_g1_to_pin = [
|
||||
self._arm_joint_names_g1.index(name) for name in self._arm_joint_names_pin
|
||||
]
|
||||
# Inverse mapping to return tau in G1 motor order.
|
||||
self._arm_reorder_pin_to_g1 = np.argsort(self._arm_reorder_g1_to_pin)
|
||||
|
||||
self.reduced_robot.model.addFrame(
|
||||
self._pin.Frame(
|
||||
"L_ee",
|
||||
self.reduced_robot.model.getJointId("left_wrist_yaw_joint"),
|
||||
self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T),
|
||||
self._pin.FrameType.OP_FRAME,
|
||||
)
|
||||
)
|
||||
|
||||
self.reduced_robot.model.addFrame(
|
||||
self._pin.Frame(
|
||||
"R_ee",
|
||||
self.reduced_robot.model.getJointId("right_wrist_yaw_joint"),
|
||||
self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T),
|
||||
self._pin.FrameType.OP_FRAME,
|
||||
)
|
||||
)
|
||||
|
||||
# Creating Casadi models and data for symbolic computing
|
||||
self.cmodel = cpin.Model(self.reduced_robot.model)
|
||||
self.cdata = self.cmodel.createData()
|
||||
|
||||
# Creating symbolic variables
|
||||
self.cq = casadi.SX.sym("q", self.reduced_robot.model.nq, 1)
|
||||
self.cTf_l = casadi.SX.sym("tf_l", 4, 4)
|
||||
self.cTf_r = casadi.SX.sym("tf_r", 4, 4)
|
||||
cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq)
|
||||
|
||||
# Get the hand joint ID and define the error function
|
||||
self.L_hand_id = self.reduced_robot.model.getFrameId("L_ee")
|
||||
self.R_hand_id = self.reduced_robot.model.getFrameId("R_ee")
|
||||
|
||||
self.translational_error = casadi.Function(
|
||||
"translational_error",
|
||||
[self.cq, self.cTf_l, self.cTf_r],
|
||||
[
|
||||
casadi.vertcat(
|
||||
self.cdata.oMf[self.L_hand_id].translation - self.cTf_l[:3, 3],
|
||||
self.cdata.oMf[self.R_hand_id].translation - self.cTf_r[:3, 3],
|
||||
)
|
||||
],
|
||||
)
|
||||
self.rotational_error = casadi.Function(
|
||||
"rotational_error",
|
||||
[self.cq, self.cTf_l, self.cTf_r],
|
||||
[
|
||||
casadi.vertcat(
|
||||
cpin.log3(self.cdata.oMf[self.L_hand_id].rotation @ self.cTf_l[:3, :3].T),
|
||||
cpin.log3(self.cdata.oMf[self.R_hand_id].rotation @ self.cTf_r[:3, :3].T),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Defining the optimization problem
|
||||
self.opti = casadi.Opti()
|
||||
self.var_q = self.opti.variable(self.reduced_robot.model.nq)
|
||||
self.var_q_last = self.opti.parameter(self.reduced_robot.model.nq) # for smooth
|
||||
self.param_tf_l = self.opti.parameter(4, 4)
|
||||
self.param_tf_r = self.opti.parameter(4, 4)
|
||||
self.translational_cost = casadi.sumsqr(
|
||||
self.translational_error(self.var_q, self.param_tf_l, self.param_tf_r)
|
||||
)
|
||||
self.rotation_cost = casadi.sumsqr(
|
||||
self.rotational_error(self.var_q, self.param_tf_l, self.param_tf_r)
|
||||
)
|
||||
self.regularization_cost = casadi.sumsqr(self.var_q)
|
||||
self.smooth_cost = casadi.sumsqr(self.var_q - self.var_q_last)
|
||||
|
||||
# Setting optimization constraints and goals
|
||||
self.opti.subject_to(
|
||||
self.opti.bounded(
|
||||
self.reduced_robot.model.lowerPositionLimit,
|
||||
self.var_q,
|
||||
self.reduced_robot.model.upperPositionLimit,
|
||||
)
|
||||
)
|
||||
self.opti.minimize(
|
||||
50 * self.translational_cost
|
||||
+ self.rotation_cost
|
||||
+ 0.02 * self.regularization_cost
|
||||
+ 0.1 * self.smooth_cost
|
||||
)
|
||||
|
||||
opts = {
|
||||
"ipopt": {"print_level": 0, "max_iter": 50, "tol": 1e-6},
|
||||
"print_time": False, # print or not
|
||||
"calc_lam_p": False, # https://github.com/casadi/casadi/wiki/FAQ:-Why-am-I-getting-%22NaN-detected%22in-my-optimization%3F
|
||||
}
|
||||
self.opti.solver("ipopt", opts)
|
||||
|
||||
self.init_data = np.zeros(self.reduced_robot.model.nq)
|
||||
self.smooth_filter = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), 14)
|
||||
|
||||
def solve_ik(self, left_wrist, right_wrist, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
||||
if current_lr_arm_motor_q is not None:
|
||||
self.init_data = current_lr_arm_motor_q
|
||||
self.opti.set_initial(self.var_q, self.init_data)
|
||||
|
||||
self.opti.set_value(self.param_tf_l, left_wrist)
|
||||
self.opti.set_value(self.param_tf_r, right_wrist)
|
||||
self.opti.set_value(self.var_q_last, self.init_data) # for smooth
|
||||
|
||||
try:
|
||||
self.opti.solve()
|
||||
|
||||
sol_q = self.opti.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
sol_q,
|
||||
v,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
|
||||
return sol_q, sol_tauff
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR in convergence, plotting debug info.{e}")
|
||||
|
||||
sol_q = self.opti.debug.value(self.var_q)
|
||||
self.smooth_filter.add_data(sol_q)
|
||||
sol_q = self.smooth_filter.filtered_data
|
||||
|
||||
if current_lr_arm_motor_dq is not None:
|
||||
v = current_lr_arm_motor_dq * 0.0
|
||||
else:
|
||||
v = (sol_q - self.init_data) * 0.0
|
||||
|
||||
self.init_data = sol_q
|
||||
|
||||
logger.error(
|
||||
f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}"
|
||||
)
|
||||
|
||||
return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv)
|
||||
|
||||
def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None):
|
||||
try:
|
||||
q_g1 = np.array(current_lr_arm_motor_q, dtype=float)
|
||||
if q_g1.shape[0] != len(self._arm_joint_names_g1):
|
||||
raise ValueError(f"Expected {len(self._arm_joint_names_g1)} arm joints, got {q_g1.shape[0]}")
|
||||
q_pin = q_g1[self._arm_reorder_g1_to_pin]
|
||||
sol_tauff = self._pin.rnea(
|
||||
self.reduced_robot.model,
|
||||
self.reduced_robot.data,
|
||||
q_pin,
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
np.zeros(self.reduced_robot.model.nv),
|
||||
)
|
||||
return sol_tauff[self._arm_reorder_pin_to_g1]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR in convergence, plotting debug info.{e}")
|
||||
return np.zeros(self.reduced_robot.model.nv)
|
||||
@@ -27,7 +27,8 @@ import numpy as np
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
@@ -127,6 +128,8 @@ class UnitreeG1(Robot):
|
||||
self.subscribe_thread = None
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
self.arm_ik = G1_29_ArmIK()
|
||||
|
||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||
while not self._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
@@ -361,6 +364,20 @@ class UnitreeG1(Robot):
|
||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
||||
self.msg.motor_cmd[motor.value].tau = 0
|
||||
|
||||
if self.config.gravity_compensation:
|
||||
# Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13)
|
||||
action_np = np.zeros(14)
|
||||
arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
action_np[local_idx] = self.msg.motor_cmd[joint.value].q
|
||||
tau = self.arm_ik.solve_tau(action_np)
|
||||
|
||||
# Apply tau back to motor commands
|
||||
for joint in G1_29_JointArmIndex:
|
||||
local_idx = joint.value - arm_start_idx
|
||||
self.msg.motor_cmd[joint.value].tau = tau[local_idx]
|
||||
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
return action
|
||||
|
||||
@@ -60,6 +60,14 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .reachy2 import Reachy2Robot
|
||||
|
||||
return Reachy2Robot(config)
|
||||
elif config.type == "openarm_follower":
|
||||
from .openarm_follower import OpenArmFollower
|
||||
|
||||
return OpenArmFollower(config)
|
||||
elif config.type == "bi_openarm_follower":
|
||||
from .bi_openarm_follower import BiOpenArmFollower
|
||||
|
||||
return BiOpenArmFollower(config)
|
||||
elif config.type == "mock_robot":
|
||||
from tests.mocks.mock_robot import MockRobot
|
||||
|
||||
|
||||
@@ -36,23 +36,28 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
lekiwi,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.utils import init_logging
|
||||
@@ -81,8 +86,11 @@ def calibrate(cfg: CalibrateConfig):
|
||||
device = make_teleoperator_from_config(cfg.device)
|
||||
|
||||
device.connect(calibrate=False)
|
||||
device.calibrate()
|
||||
device.disconnect()
|
||||
|
||||
try:
|
||||
device.calibrate()
|
||||
finally:
|
||||
device.disconnect()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
Edit LeRobot datasets using various transformation tools.
|
||||
|
||||
This script allows you to delete episodes, split datasets, merge datasets,
|
||||
remove features, and convert image datasets to video format.
|
||||
remove features, modify tasks, and convert image datasets to video format.
|
||||
When new_repo_id is specified, creates a new dataset.
|
||||
|
||||
Usage Examples:
|
||||
@@ -66,6 +66,25 @@ Remove camera feature:
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Modify tasks - set a single task for all episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Pick up the cube and place it"
|
||||
|
||||
Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}'
|
||||
|
||||
Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type modify_tasks \
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -100,6 +119,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
convert_image_to_video_dataset,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
@@ -132,6 +152,13 @@ class RemoveFeatureConfig:
|
||||
feature_names: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModifyTasksConfig:
|
||||
type: str = "modify_tasks"
|
||||
new_task: str | None = None
|
||||
episode_tasks: dict[str, str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig:
|
||||
type: str = "convert_image_to_video"
|
||||
@@ -151,7 +178,12 @@ class ConvertImageToVideoConfig:
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: (
|
||||
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig
|
||||
DeleteEpisodesConfig
|
||||
| SplitConfig
|
||||
| MergeConfig
|
||||
| RemoveFeatureConfig
|
||||
| ModifyTasksConfig
|
||||
| ConvertImageToVideoConfig
|
||||
)
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
@@ -296,6 +328,48 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, ModifyTasksConfig):
|
||||
raise ValueError("Operation config must be ModifyTasksConfig")
|
||||
|
||||
new_task = cfg.operation.new_task
|
||||
episode_tasks_raw = cfg.operation.episode_tasks
|
||||
|
||||
if new_task is None and episode_tasks_raw is None:
|
||||
raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation")
|
||||
|
||||
# Warn about in-place modification behavior
|
||||
if cfg.new_repo_id is not None:
|
||||
logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.")
|
||||
|
||||
# Convert episode_tasks keys from string to int if needed (CLI passes strings)
|
||||
episode_tasks: dict[int, str] | None = None
|
||||
if episode_tasks_raw is not None:
|
||||
episode_tasks = {int(k): v for k, v in episode_tasks_raw.items()}
|
||||
|
||||
logging.info(f"Modifying tasks in {cfg.repo_id}")
|
||||
if new_task:
|
||||
logging.info(f" Default task: '{new_task}'")
|
||||
if episode_tasks:
|
||||
logging.info(f" Episode-specific tasks: {episode_tasks}")
|
||||
|
||||
modified_dataset = modify_tasks(
|
||||
dataset,
|
||||
new_task=new_task,
|
||||
episode_tasks=episode_tasks,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset modified at {dataset.root}")
|
||||
logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {cfg.repo_id}")
|
||||
modified_dataset.push_to_hub()
|
||||
|
||||
|
||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
@@ -371,12 +445,14 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_merge(cfg)
|
||||
elif operation_type == "remove_feature":
|
||||
handle_remove_feature(cfg)
|
||||
elif operation_type == "modify_tasks":
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown operation type: {operation_type}\n"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -44,19 +44,23 @@ import numpy as np
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
so_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -98,26 +98,31 @@ from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
unitree_g1 as unitree_g1_robot,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
|
||||
@@ -53,12 +53,14 @@ from lerobot.processor import (
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
@@ -108,25 +110,26 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(len(episode_frames)):
|
||||
start_episode_t = time.perf_counter()
|
||||
try:
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(len(episode_frames)):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
action[name] = action_array[i]
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
action[name] = action_array[i]
|
||||
|
||||
robot_obs = robot.get_observation()
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
processed_action = robot_action_processor((action, robot_obs))
|
||||
processed_action = robot_action_processor((action, robot_obs))
|
||||
|
||||
_ = robot.send_action(processed_action)
|
||||
_ = robot.send_action(processed_action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(max(1 / dataset.fps - dt_s, 0.0))
|
||||
finally:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -45,7 +45,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
|
||||
MOTOR_NAMES = {
|
||||
0x01: "joint_1",
|
||||
@@ -336,7 +336,7 @@ def run_speed(cfg: CANSetupConfig):
|
||||
|
||||
@draccus.wrap()
|
||||
def setup_can(cfg: CANSetupConfig):
|
||||
if not is_package_available("can"):
|
||||
if not _can_available:
|
||||
print("Error: python-can not installed. Install with: pip install python-can")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@@ -70,18 +70,22 @@ from lerobot.processor import (
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1 as unitree_g1_robot,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
homunculus,
|
||||
@@ -89,8 +93,10 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -166,9 +166,9 @@ def apply_normalization(
|
||||
if q01 is None or q99 is None:
|
||||
raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats")
|
||||
denom = np.maximum(q99 - q01, eps)
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q01, q99)
|
||||
return 2.0 * (clipped - q01) / denom - 1.0
|
||||
# No clipping: match training pipeline NormalizerProcessorStep so tokenizer
|
||||
# is fit on the full range of normalized values (including tails outside [-1, 1]).
|
||||
return 2.0 * (data - q01) / denom - 1.0
|
||||
|
||||
if mode == NormalizationMode.QUANTILE10:
|
||||
q10 = stats.get("q10")
|
||||
@@ -176,9 +176,8 @@ def apply_normalization(
|
||||
if q10 is None or q90 is None:
|
||||
raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats")
|
||||
denom = np.maximum(q90 - q10, eps)
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q10, q90)
|
||||
return 2.0 * (clipped - q10) / denom - 1.0
|
||||
# No clipping: match training pipeline NormalizerProcessorStep.
|
||||
return 2.0 * (data - q10) / denom - 1.0
|
||||
|
||||
raise ValueError(f"Unsupported normalization mode: {mode}")
|
||||
|
||||
@@ -306,7 +305,7 @@ def train_fast_tokenizer(
|
||||
|
||||
# download the tokenizer source code (not pretrained weights)
|
||||
# we'll train a new tokenizer on our own data
|
||||
base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
base_tokenizer = AutoProcessor.from_pretrained("/fsx/jade_choghari/outputs/libero_tokenizer_wavetoken1", trust_remote_code=True)
|
||||
|
||||
# convert action_chunks array to list of arrays (expected by .fit())
|
||||
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
|
||||
@@ -320,6 +319,8 @@ def train_fast_tokenizer(
|
||||
vocab_size=vocab_size,
|
||||
time_horizon=action_chunks.shape[1], # action_horizon
|
||||
action_dim=action_chunks.shape[2], # encoded dimensions
|
||||
wavelet="dmey",
|
||||
level=1,
|
||||
)
|
||||
print("✓ Tokenizer training complete!")
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
||||
|
||||
__all__ = ["BiOpenArmLeader", "BiOpenArmLeaderConfig"]
|
||||
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..openarm_leader import OpenArmLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmLeader(Teleoperator):
|
||||
"""
|
||||
Bimanual OpenArm Leader Arms
|
||||
"""
|
||||
|
||||
config_class = BiOpenArmLeaderConfig
|
||||
name = "bi_openarm_leader"
|
||||
|
||||
def __init__(self, config: BiOpenArmLeaderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
left_arm_config = OpenArmLeaderConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
can_bitrate=config.left_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
manual_control=config.left_arm_config.manual_control,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
)
|
||||
|
||||
right_arm_config = OpenArmLeaderConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
can_bitrate=config.right_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
manual_control=config.right_arm_config.manual_control,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
)
|
||||
|
||||
self.left_arm = OpenArmLeader(left_arm_config)
|
||||
self.right_arm = OpenArmLeader(right_arm_config)
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
left_arm_features = self.left_arm.action_features
|
||||
right_arm_features = self.right_arm.action_features
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_features.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_features.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_action = self.left_arm.get_action()
|
||||
action_dict.update({f"left_{key}": value for key, value in left_action.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_action = self.right_arm.get_action()
|
||||
action_dict.update({f"right_{key}": value for key, value in right_action.items()})
|
||||
|
||||
return action_dict
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
|
||||
@dataclass
|
||||
class BiOpenArmLeaderConfig(TeleoperatorConfig):
|
||||
"""Configuration class for Bi OpenArm Follower robots."""
|
||||
|
||||
left_arm_config: OpenArmLeaderConfigBase
|
||||
right_arm_config: OpenArmLeaderConfigBase
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..so_leader import SOLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -72,6 +72,7 @@ class BiSOLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -110,6 +111,7 @@ class BiSOLeader(Teleoperator):
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_leader import OpenArmLeaderConfig, OpenArmLeaderConfigBase
|
||||
from .openarm_leader import OpenArmLeader
|
||||
|
||||
__all__ = ["OpenArmLeader", "OpenArmLeaderConfig", "OpenArmLeaderConfigBase"]
|
||||
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenArmLeaderConfigBase:
|
||||
"""Base configuration for the OpenArms leader/teleoperator with Damiao motors."""
|
||||
|
||||
# CAN interfaces - one per arm
|
||||
# Arm CAN interface (e.g., "can3")
|
||||
# Linux: "can0", "can1", etc.
|
||||
port: str
|
||||
|
||||
# CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect)
|
||||
can_interface: str = "socketcan"
|
||||
|
||||
# CAN FD settings (OpenArms uses CAN FD by default)
|
||||
use_can_fd: bool = True
|
||||
can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps)
|
||||
can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps)
|
||||
|
||||
# Motor configuration for OpenArms (7 DOF per arm)
|
||||
# Maps motor names to (send_can_id, recv_can_id, motor_type)
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms uses 4 types of motors:
|
||||
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
|
||||
# - DM4340P and DM4340 for shoulder rotation and elbow
|
||||
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
|
||||
motor_config: dict[str, tuple[int, int, str]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
|
||||
"joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009)
|
||||
"joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340)
|
||||
"joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340)
|
||||
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
|
||||
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
|
||||
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
|
||||
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
|
||||
}
|
||||
)
|
||||
|
||||
# Torque mode settings for manual control
|
||||
# When enabled, motors have torque disabled for manual movement
|
||||
manual_control: bool = True
|
||||
|
||||
# TODO(Steven, Pepijn): Not used ... ?
|
||||
# MIT control parameters (used when manual_control=False for torque control)
|
||||
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
position_kp: list[float] = field(
|
||||
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 16.0]
|
||||
)
|
||||
position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2])
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_leader")
|
||||
@dataclass
|
||||
class OpenArmLeaderConfig(TeleoperatorConfig, OpenArmLeaderConfigBase):
|
||||
pass
|
||||
@@ -0,0 +1,222 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_leader import OpenArmLeaderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenArmLeader(Teleoperator):
|
||||
"""
|
||||
OpenArm Leader/Teleoperator Arm with Damiao motors.
|
||||
|
||||
This teleoperator uses CAN bus communication to read positions from
|
||||
Damiao motors that are manually moved (torque disabled).
|
||||
"""
|
||||
|
||||
config_class = OpenArmLeaderConfig
|
||||
name = "openarm_leader"
|
||||
|
||||
def __init__(self, config: OpenArmLeaderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Arm motors
|
||||
motors: dict[str, Motor] = {}
|
||||
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
|
||||
motor = Motor(
|
||||
send_id, motor_type_str, MotorNormMode.DEGREES
|
||||
) # Always use degrees for Damiao motors
|
||||
motor.recv_id = recv_id
|
||||
motor.motor_type_str = motor_type_str
|
||||
motors[motor_name] = motor
|
||||
|
||||
self.bus = DamiaoMotorsBus(
|
||||
port=self.config.port,
|
||||
motors=motors,
|
||||
calibration=self.calibration,
|
||||
can_interface=self.config.can_interface,
|
||||
use_can_fd=self.config.use_can_fd,
|
||||
bitrate=self.config.can_bitrate,
|
||||
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Features produced by this teleoperator."""
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
"""Feedback features (not implemented for OpenArms)."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if teleoperator is connected."""
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
Connect to the teleoperator.
|
||||
|
||||
For manual control, we disable torque after connecting so the
|
||||
arm can be moved by hand.
|
||||
"""
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
self.bus.connect()
|
||||
|
||||
# Run calibration if needed
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
|
||||
)
|
||||
self.calibrate()
|
||||
|
||||
self.configure()
|
||||
|
||||
if self.is_calibrated:
|
||||
self.bus.set_zero_position()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""Check if teleoperator is calibrated."""
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArms leader.
|
||||
|
||||
The calibration procedure:
|
||||
1. Disable torque (if not already disabled)
|
||||
2. Ask user to position arm in zero position (hanging with gripper closed)
|
||||
3. Set this as zero position
|
||||
4. Record range of motion for each joint
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
self.bus.disable_torque()
|
||||
|
||||
# Step 1: Set zero position
|
||||
input(
|
||||
"\nCalibration: Set Zero Position)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
# Set current position as zero for all motors
|
||||
self.bus.set_zero_position()
|
||||
logger.info("Arm zero position set.")
|
||||
|
||||
logger.info("Setting range: -90° to +90° by default for all joints")
|
||||
# TODO(Steven, Pepijn): Check if MotorCalibration is actually needed here given that we only use Degrees
|
||||
for motor_name, motor in self.bus.motors.items():
|
||||
self.calibration[motor_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=0,
|
||||
homing_offset=0,
|
||||
range_min=-90,
|
||||
range_max=90,
|
||||
)
|
||||
|
||||
self.bus.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
print(f"Calibration saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
"""
|
||||
Configure motors for manual teleoperation.
|
||||
|
||||
For manual control, we disable torque so the arm can be moved by hand.
|
||||
"""
|
||||
|
||||
return self.bus.disable_torque() if self.config.manual_control else self.bus.configure_motors()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Get current action from the leader arm.
|
||||
|
||||
This is the main method for teleoperators - it reads the current state
|
||||
of the leader arm and returns it as an action that can be sent to a follower.
|
||||
|
||||
Reads all motor states (pos/vel/torque) in one CAN refresh cycle.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
|
||||
action_dict: dict[str, Any] = {}
|
||||
|
||||
# Use sync_read_all_states to get pos/vel/torque in one go
|
||||
states = self.bus.sync_read_all_states()
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
action_dict[f"{motor}.pos"] = state.get("position")
|
||||
action_dict[f"{motor}.vel"] = state.get("velocity")
|
||||
action_dict[f"{motor}.torque"] = state.get("torque")
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
return action_dict
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from teleoperator."""
|
||||
|
||||
# Disconnect CAN bus
|
||||
# For manual control, ensure torque is disabled before disconnecting
|
||||
self.bus.disconnect(disable_torque=self.config.manual_control)
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_unitree_g1 import ExoskeletonArmPortConfig, UnitreeG1TeleoperatorConfig
|
||||
from .exo_calib import ExoskeletonCalibration, ExoskeletonJointCalibration
|
||||
from .exo_ik import ExoskeletonIKHelper
|
||||
from .exo_serial import ExoskeletonArm
|
||||
from .unitree_g1 import UnitreeG1Teleoperator
|
||||
@@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExoskeletonArmPortConfig:
|
||||
"""Serial port configuration for individual exoskeleton arm."""
|
||||
|
||||
port: str = ""
|
||||
baud_rate: int = 115200
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("unitree_g1")
|
||||
@dataclass
|
||||
class UnitreeG1TeleoperatorConfig(TeleoperatorConfig):
|
||||
left_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig)
|
||||
right_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig)
|
||||
|
||||
# Frozen joints (comma-separated joint names that won't be moved by IK)
|
||||
frozen_joints: str = ""
|
||||
@@ -0,0 +1,446 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This module handles calibration of hall effect sensors used in the exoskeleton.
|
||||
Each joint has a pair of ADC channels outputting sin and cos values that trace an ellipse
|
||||
as the joint rotates due to imprecision in magnet/sensor placement. We fit this ellipse to a unit circle,
|
||||
and calculate arctan2 of the unit circle to get the joint angle.
|
||||
We then store the ellipse parameters and the zero offset for each joint to be used at runtime.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import serial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw
|
||||
JOINTS = {
|
||||
"shoulder_pitch": (0, 1),
|
||||
"shoulder_yaw": (2, 3),
|
||||
"shoulder_roll": (4, 5),
|
||||
"elbow_flex": (6, 7),
|
||||
"wrist_roll": (14, 15),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExoskeletonJointCalibration:
|
||||
name: str # joint name
|
||||
center_fit: list[float] # center of the ellipse
|
||||
T: list[list[float]] # 2x2 transformation matrix
|
||||
zero_offset: float = 0.0 # angle at neutral pose
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExoskeletonCalibration:
|
||||
"""Full calibration data for an exoskeleton arm."""
|
||||
|
||||
version: int = 2
|
||||
side: str = ""
|
||||
adc_max: int = 2**12 - 1
|
||||
joints: list[ExoskeletonJointCalibration] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"version": self.version,
|
||||
"side": self.side,
|
||||
"adc_max": self.adc_max,
|
||||
"joints": [
|
||||
{
|
||||
"name": j.name,
|
||||
"center_fit": j.center_fit,
|
||||
"T": j.T,
|
||||
"zero_offset": j.zero_offset,
|
||||
}
|
||||
for j in self.joints
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "ExoskeletonCalibration":
|
||||
joints = [
|
||||
ExoskeletonJointCalibration(
|
||||
name=j["name"],
|
||||
center_fit=j["center_fit"],
|
||||
T=j["T"],
|
||||
zero_offset=j.get("zero_offset", 0.0),
|
||||
)
|
||||
for j in data.get("joints", [])
|
||||
]
|
||||
return cls(
|
||||
version=data.get("version", 2),
|
||||
side=data.get("side", ""),
|
||||
adc_max=data.get("adc_max", 2**12 - 1),
|
||||
joints=joints,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CalibParams:
|
||||
fit_every: float = 0.15
|
||||
min_fit_points: int = 60
|
||||
fit_window: int = 900
|
||||
max_fit_points: int = 300
|
||||
trim_low: float = 0.05
|
||||
trim_high: float = 0.95
|
||||
median_window: int = 5
|
||||
history: int = 3500
|
||||
draw_hz: float = 120.0
|
||||
sample_count: int = 50
|
||||
|
||||
|
||||
def normalize_angle(angle: float) -> float:
|
||||
while angle > np.pi:
|
||||
angle -= 2 * np.pi
|
||||
while angle < -np.pi:
|
||||
angle += 2 * np.pi
|
||||
return angle
|
||||
|
||||
|
||||
def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]:
|
||||
"""
|
||||
Applies calibration to each joint: raw → centered → ellipse-to-circle → angle.
|
||||
"""
|
||||
pair = JOINTS[j.name]
|
||||
s, c = raw16[pair[0]], raw16[pair[1]] # get sin and cos
|
||||
p = np.array([float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2]) # center the raw values
|
||||
z = np.asarray(j.T) @ (
|
||||
p - np.asarray(j.center_fit)
|
||||
) # center the ellipse and invert the transformation matrix to get unit circle coords
|
||||
ang = float(np.arctan2(z[1], z[0])) - j.zero_offset # calculate the anvgle and apply the zero offset
|
||||
return z, normalize_angle(-ang) # ensure range is [-pi, pi]
|
||||
|
||||
|
||||
def exo_raw_to_angles(raw16: list[int], calib: ExoskeletonCalibration) -> dict[str, float]:
|
||||
"""Convert raw sensor readings to joint angles using calibration."""
|
||||
return {j.name: joint_z_and_angle(raw16, j)[1] for j in calib.joints}
|
||||
|
||||
|
||||
def run_exo_calibration(
|
||||
ser: serial.Serial,
|
||||
side: str,
|
||||
save_path: Path,
|
||||
params: CalibParams | None = None,
|
||||
) -> ExoskeletonCalibration:
|
||||
"""
|
||||
Run interactive calibration for an exoskeleton arm.
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Calibration requires matplotlib and opencv-python. "
|
||||
"Install with: pip install matplotlib opencv-python"
|
||||
) from e
|
||||
|
||||
from .exo_serial import read_raw_from_serial
|
||||
|
||||
params = params or CalibParams()
|
||||
joint_list = list(JOINTS.items()) # Convert dict to list for indexing
|
||||
logger.info(f"Starting calibration for {side} exoskeleton arm")
|
||||
|
||||
def running_median(win: deque) -> float:
|
||||
return float(np.median(np.fromiter(win, dtype=float)))
|
||||
|
||||
def read_joint_point(raw16: list[int], pair: tuple[int, int]):
|
||||
s, c = raw16[pair[0]], raw16[pair[1]]
|
||||
return float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2, float(s), float(c)
|
||||
|
||||
def select_fit_subset(xs, ys):
|
||||
"""Select and filter points for ellipse fitting. Trims outliers by radius and downsamples."""
|
||||
n = min(params.fit_window, len(xs))
|
||||
if n <= 0:
|
||||
return None, None
|
||||
x = np.asarray(list(xs)[-n:], dtype=float) # most recent n samples
|
||||
y = np.asarray(list(ys)[-n:], dtype=float)
|
||||
r = np.sqrt(x * x + y * y) # radius from origin
|
||||
if len(r) >= 20:
|
||||
lo, hi = np.quantile(r, params.trim_low), np.quantile(r, params.trim_high) # outlier bounds
|
||||
keep = (r >= lo) & (r <= hi)
|
||||
x, y = x[keep], y[keep] # remove outliers
|
||||
if len(x) > params.max_fit_points:
|
||||
idx = np.linspace(0, len(x) - 1, params.max_fit_points).astype(int) # downsample evenly
|
||||
x, y = x[idx], y[idx]
|
||||
return x, y
|
||||
|
||||
def fit_ellipse_opencv(x, y):
|
||||
"""Fit ellipse to (x,y) points using OpenCV. Returns center, axes, rotation matrix, and outline."""
|
||||
x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
|
||||
if len(x) < 5:
|
||||
return None
|
||||
pts = np.stack([x, y], axis=1).astype(np.float32).reshape(-1, 1, 2)
|
||||
try:
|
||||
(xc, yc), (w, h), angle_deg = cv2.fitEllipse(pts) # returns center, axes, rotation in degrees
|
||||
except cv2.error:
|
||||
return None
|
||||
a, b = float(w) * 0.5, float(h) * 0.5 # get ellipse major and minor semi-axes
|
||||
phi = np.deg2rad(float(angle_deg)) # to rad
|
||||
if b > a: # ensure major axis is a
|
||||
a, b = b, a
|
||||
phi += np.pi / 2.0
|
||||
if not np.isfinite(a) or not np.isfinite(b) or a <= 1e-6 or b <= 1e-6:
|
||||
return None
|
||||
cp, sp = float(np.cos(phi)), float(np.sin(phi)) #
|
||||
rot = np.array([[cp, -sp], [sp, cp]], dtype=float) # 2x2 rotation matrix
|
||||
center = np.array([float(xc), float(yc)], dtype=float) # offset vector
|
||||
tt = np.linspace(0, 2 * np.pi, 360)
|
||||
outline = (rot @ np.stack([a * np.cos(tt), b * np.sin(tt)])).T + center # for viz
|
||||
return {"center": center, "a": a, "b": b, "R": rot, "ex": outline[:, 0], "ey": outline[:, 1]}
|
||||
|
||||
# Setup matplotlib
|
||||
plt.ion()
|
||||
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6))
|
||||
ax0.set_xlabel("cos - center")
|
||||
ax0.set_ylabel("sin - center")
|
||||
ax0.grid(True, alpha=0.25)
|
||||
ax0.set_aspect("equal", adjustable="box")
|
||||
ax1.set_title("Unit circle + angle")
|
||||
ax1.set_xlabel("x")
|
||||
ax1.set_ylabel("y")
|
||||
ax1.grid(True, alpha=0.25)
|
||||
ax1.set_aspect("equal", adjustable="box")
|
||||
tt = np.linspace(0, 2 * np.pi, 360)
|
||||
ax1.plot(np.cos(tt), np.sin(tt), "k-", linewidth=1)
|
||||
ax0.set_xlim(-2200, 2200)
|
||||
ax0.set_ylim(-2200, 2200)
|
||||
ax1.set_xlim(-1.4, 1.4)
|
||||
ax1.set_ylim(-1.4, 1.4)
|
||||
|
||||
sc0 = ax0.scatter([], [], s=6, animated=True)
|
||||
(ell_line,) = ax0.plot([], [], "r-", linewidth=2, animated=True)
|
||||
sc1 = ax1.scatter([], [], s=6, animated=True)
|
||||
(radius_line,) = ax1.plot([], [], "g-", linewidth=2, animated=True)
|
||||
angle_text = ax1.text(
|
||||
0.02, 0.98, "", transform=ax1.transAxes, va="top", ha="left", fontsize=12, animated=True
|
||||
)
|
||||
|
||||
fig.canvas.draw()
|
||||
bg0 = fig.canvas.copy_from_bbox(ax0.bbox)
|
||||
bg1 = fig.canvas.copy_from_bbox(ax1.bbox)
|
||||
|
||||
# State
|
||||
joints_out = []
|
||||
joint_idx = 0
|
||||
phase = "ellipse"
|
||||
advance_requested = False
|
||||
zero_samples = []
|
||||
|
||||
def on_key(event):
|
||||
nonlocal advance_requested
|
||||
if event.key in ("n", "N", "enter", " "):
|
||||
advance_requested = True
|
||||
|
||||
fig.canvas.mpl_connect("key_press_event", on_key)
|
||||
|
||||
def reset_state():
|
||||
return {
|
||||
"xs": deque(maxlen=params.history),
|
||||
"ys": deque(maxlen=params.history),
|
||||
"xu": deque(maxlen=params.history),
|
||||
"yu": deque(maxlen=params.history),
|
||||
"win_s": deque(maxlen=params.median_window),
|
||||
"win_c": deque(maxlen=params.median_window),
|
||||
"ellipse_cache": None,
|
||||
"T": None,
|
||||
"center_fit": None,
|
||||
"have_transform": False,
|
||||
"latest_z": None,
|
||||
"last_fit": 0.0,
|
||||
}
|
||||
|
||||
state = reset_state()
|
||||
last_draw = 0.0
|
||||
name, pair = joint_list[joint_idx]
|
||||
fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE")
|
||||
ax0.set_title(f"{name} raw (filtered)")
|
||||
logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}")
|
||||
logger.info("Step 1: Move joint around to map ellipse, then press 'n'")
|
||||
|
||||
try:
|
||||
while plt.fignum_exists(fig.number):
|
||||
name, pair = joint_list[joint_idx]
|
||||
|
||||
# Handles calibration GUI state: ellipse → zero_pose → next joint -> ellipse -> ...
|
||||
if phase == "ellipse" and advance_requested and state["have_transform"]:
|
||||
joints_out.append(
|
||||
{
|
||||
"name": name,
|
||||
"center_fit": state["center_fit"].tolist(),
|
||||
"T": state["T"].tolist(),
|
||||
}
|
||||
)
|
||||
logger.info(f" -> Ellipse saved for {name}")
|
||||
phase, zero_samples, advance_requested = "zero_pose", [], False
|
||||
fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ZERO POSE")
|
||||
ax0.set_title(f"{name} - hold zero pose")
|
||||
fig.canvas.draw()
|
||||
bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox)
|
||||
logger.info(f"Step 2: Hold {name} in zero position, then press 'n'")
|
||||
|
||||
elif phase == "ellipse" and advance_requested and not state["have_transform"]:
|
||||
logger.info(" (Need valid fit first - keep moving the joint)")
|
||||
advance_requested = False
|
||||
|
||||
elif phase == "zero_pose" and advance_requested:
|
||||
if len(zero_samples) >= params.sample_count:
|
||||
zero_offset = float(np.mean(zero_samples[-params.sample_count :]))
|
||||
joints_out[-1]["zero_offset"] = zero_offset
|
||||
logger.info(f" -> {name} zero: {zero_offset:+.3f} rad ({np.degrees(zero_offset):+.1f}°)")
|
||||
joint_idx += 1
|
||||
advance_requested = False
|
||||
|
||||
if joint_idx >= len(joint_list):
|
||||
# All joints done
|
||||
calib = ExoskeletonCalibration(
|
||||
version=2,
|
||||
side=side,
|
||||
adc_max=2**12 - 1,
|
||||
joints=[
|
||||
ExoskeletonJointCalibration(
|
||||
name=j["name"],
|
||||
center_fit=j["center_fit"],
|
||||
T=j["T"],
|
||||
zero_offset=j.get("zero_offset", 0.0),
|
||||
)
|
||||
for j in joints_out
|
||||
],
|
||||
)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(save_path, "w") as f:
|
||||
json.dump(calib.to_dict(), f, indent=2)
|
||||
logger.info(f"Saved calibration to {save_path}")
|
||||
logger.info("Calibration complete!")
|
||||
plt.close(fig)
|
||||
return calib
|
||||
|
||||
# Next joint
|
||||
phase, state = "ellipse", reset_state()
|
||||
name, pair = joint_list[joint_idx]
|
||||
fig.canvas.manager.set_window_title(
|
||||
f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE"
|
||||
)
|
||||
ax0.set_title(f"{name} raw (filtered)")
|
||||
fig.canvas.draw()
|
||||
bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox)
|
||||
logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}")
|
||||
logger.info("Step 1: Move joint around to map ellipse, then press 'n'")
|
||||
else:
|
||||
logger.info(
|
||||
f" (Collecting samples: {len(zero_samples)}/{params.sample_count} - hold still)"
|
||||
)
|
||||
advance_requested = False
|
||||
|
||||
# Read sensor
|
||||
raw16 = read_raw_from_serial(ser)
|
||||
if raw16 is not None:
|
||||
x_raw, y_raw, s_raw, c_raw = read_joint_point(raw16, pair)
|
||||
|
||||
if phase == "ellipse":
|
||||
if state["have_transform"]:
|
||||
z = state["T"] @ (np.array([x_raw, y_raw]) - state["center_fit"])
|
||||
state["xu"].append(float(z[0]))
|
||||
state["yu"].append(float(z[1]))
|
||||
state["latest_z"] = (float(z[0]), float(z[1]))
|
||||
state["win_s"].append(s_raw)
|
||||
state["win_c"].append(c_raw)
|
||||
if len(state["win_s"]) >= max(3, params.median_window):
|
||||
state["ys"].append(running_median(state["win_s"]) - (2**12 - 1) / 2)
|
||||
state["xs"].append(running_median(state["win_c"]) - (2**12 - 1) / 2)
|
||||
else:
|
||||
jdata = joints_out[-1]
|
||||
z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"]))
|
||||
zero_samples.append(float(np.arctan2(z[1], z[0])))
|
||||
state["latest_z"] = (float(z[0]), float(z[1]))
|
||||
|
||||
# Ellipse fitting
|
||||
t = time.time()
|
||||
if (
|
||||
phase == "ellipse"
|
||||
and (t - state["last_fit"]) >= params.fit_every
|
||||
and len(state["xs"]) >= params.min_fit_points
|
||||
):
|
||||
xfit, yfit = select_fit_subset(state["xs"], state["ys"])
|
||||
if xfit is not None and len(xfit) >= params.min_fit_points:
|
||||
fit = fit_ellipse_opencv(xfit, yfit)
|
||||
if fit is not None:
|
||||
state["center_fit"] = fit["center"]
|
||||
state["T"] = np.diag([1.0 / fit["a"], 1.0 / fit["b"]]) @ fit["R"].T
|
||||
state["ellipse_cache"] = (fit["ex"], fit["ey"])
|
||||
state["have_transform"] = True
|
||||
state["last_fit"] = t
|
||||
|
||||
# Drawing
|
||||
if (t - last_draw) >= 1.0 / params.draw_hz:
|
||||
fig.canvas.restore_region(bg0)
|
||||
fig.canvas.restore_region(bg1)
|
||||
|
||||
if phase == "ellipse":
|
||||
sc0.set_offsets(np.c_[state["xs"], state["ys"]] if state["xs"] else np.empty((0, 2)))
|
||||
ax0.draw_artist(sc0)
|
||||
ell_line.set_data(*state["ellipse_cache"] if state["ellipse_cache"] else ([], []))
|
||||
ax0.draw_artist(ell_line)
|
||||
sc1.set_offsets(np.c_[state["xu"], state["yu"]] if state["xu"] else np.empty((0, 2)))
|
||||
ax1.draw_artist(sc1)
|
||||
if state["latest_z"]:
|
||||
zx, zy = state["latest_z"]
|
||||
radius_line.set_data([0.0, zx], [0.0, zy])
|
||||
ang = float(np.arctan2(zy, zx))
|
||||
angle_text.set_text(
|
||||
f"angle: {ang:+.3f} rad ({np.degrees(ang):+.1f}°)\nmove {name}, press 'n' to advance"
|
||||
)
|
||||
else:
|
||||
radius_line.set_data([], [])
|
||||
angle_text.set_text("(waiting for fit)")
|
||||
else:
|
||||
sc0.set_offsets(np.empty((0, 2)))
|
||||
ax0.draw_artist(sc0)
|
||||
ell_line.set_data([], [])
|
||||
ax0.draw_artist(ell_line)
|
||||
if state["latest_z"]:
|
||||
zx, zy = state["latest_z"]
|
||||
sc1.set_offsets([[zx, zy]])
|
||||
radius_line.set_data([0.0, zx], [0.0, zy])
|
||||
ang = float(np.arctan2(zy, zx))
|
||||
angle_text.set_text(
|
||||
f"Zero pose for {name}\nangle: {ang:+.3f} rad\nsamples: {len(zero_samples)}/{params.sample_count}\nhold still, press 'n'"
|
||||
)
|
||||
else:
|
||||
sc1.set_offsets(np.empty((0, 2)))
|
||||
radius_line.set_data([], [])
|
||||
angle_text.set_text("(waiting for data)")
|
||||
ax1.draw_artist(sc1)
|
||||
|
||||
ax1.draw_artist(radius_line)
|
||||
ax1.draw_artist(angle_text)
|
||||
fig.canvas.blit(ax0.bbox)
|
||||
fig.canvas.blit(ax1.bbox)
|
||||
fig.canvas.flush_events()
|
||||
last_draw = t
|
||||
|
||||
plt.pause(0.001)
|
||||
|
||||
finally:
|
||||
plt.close(fig)
|
||||
@@ -0,0 +1,353 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
IK helper for exoskeleton-to-G1 teleoperation. We map Exoskeleton joint angles to end-effector pose in world frame,
|
||||
visualizing the result in meshcat after calibration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex
|
||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||
|
||||
from .exo_calib import JOINTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _frame_id(model, name: str) -> int | None:
|
||||
try:
|
||||
fid = model.getFrameId(name)
|
||||
return fid if 0 <= fid < model.nframes else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArmCfg:
|
||||
side: str # "left" | "right"
|
||||
urdf: str # exo_left.urdf / exo_right.urdf
|
||||
root: str # "exo_left" / "exo_right"
|
||||
g1_ee: str # "l_ee" / "r_ee"
|
||||
offset: np.ndarray # world offset for viz + target
|
||||
marker_prefix: str # "left" / "right"
|
||||
|
||||
|
||||
class Markers:
|
||||
"""Creates meshcat visualization primitives, showing end-effector frames of exoskeleton and G1"""
|
||||
|
||||
def __init__(self, viewer):
|
||||
self.v = viewer
|
||||
|
||||
def sphere(self, path: str, r: float, rgba: tuple[float, float, float, float]):
|
||||
import meshcat.geometry as mg
|
||||
|
||||
c = (int(rgba[0] * 255) << 16) | (int(rgba[1] * 255) << 8) | int(rgba[2] * 255)
|
||||
self.v[path].set_object(
|
||||
mg.Sphere(r),
|
||||
mg.MeshPhongMaterial(color=c, opacity=rgba[3], transparent=rgba[3] < 1.0),
|
||||
)
|
||||
|
||||
def axes(self, path: str, axis_len: float = 0.1, axis_w: int = 6):
|
||||
import meshcat.geometry as mg
|
||||
|
||||
pts = np.array(
|
||||
[[0, 0, 0], [axis_len, 0, 0], [0, 0, 0], [0, axis_len, 0], [0, 0, 0], [0, 0, axis_len]],
|
||||
dtype=np.float32,
|
||||
).T
|
||||
cols = np.array(
|
||||
[[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]],
|
||||
dtype=np.float32,
|
||||
).T
|
||||
self.v[path].set_object(
|
||||
mg.LineSegments(
|
||||
mg.PointsGeometry(position=pts, color=cols),
|
||||
mg.LineBasicMaterial(linewidth=axis_w, vertexColors=True),
|
||||
)
|
||||
)
|
||||
|
||||
def tf(self, path: str, mat: np.ndarray):
|
||||
self.v[path].set_transform(mat)
|
||||
|
||||
|
||||
class ExoskeletonIKHelper:
|
||||
"""
|
||||
- Loads G1 robot and exoskeleton URDF models via Pinocchio
|
||||
- Computes forward kinematics on exoskeleton to get end-effector poses
|
||||
- Solves inverse kinematics on G1 to match those poses
|
||||
- Provides meshcat visualization showing both robots and targets
|
||||
|
||||
Args:
|
||||
frozen_joints: List of G1 joint names to exclude from IK (kept at neutral).
|
||||
"""
|
||||
|
||||
def __init__(self, frozen_joints: list[str] | None = None):
|
||||
try:
|
||||
import pinocchio as pin
|
||||
except ImportError as e:
|
||||
raise ImportError("ik mode needs pinocchio: pip install pin") from e
|
||||
|
||||
self.pin = pin
|
||||
self.frozen_joints = frozen_joints or []
|
||||
|
||||
self.g1_ik = G1_29_ArmIK()
|
||||
self.robot_g1 = self.g1_ik.reduced_robot
|
||||
self.robot_g1.data = self.robot_g1.model.createData()
|
||||
self.q_g1 = pin.neutral(self.robot_g1.model)
|
||||
|
||||
assets_dir = os.path.join(self.g1_ik.repo_path, "assets")
|
||||
|
||||
self.frozen_idx = self._frozen_joint_indices()
|
||||
|
||||
self.arms = [
|
||||
ArmCfg(
|
||||
side="left",
|
||||
urdf=os.path.join(assets_dir, "exo_left.urdf"),
|
||||
root="exo_left",
|
||||
g1_ee="L_ee",
|
||||
offset=np.array([0.6, 0.3, 0.0]),
|
||||
marker_prefix="left",
|
||||
),
|
||||
ArmCfg(
|
||||
side="right",
|
||||
urdf=os.path.join(assets_dir, "exo_right.urdf"),
|
||||
root="exo_right",
|
||||
g1_ee="R_ee",
|
||||
offset=np.array([0.6, -0.3, 0.0]),
|
||||
marker_prefix="right",
|
||||
),
|
||||
]
|
||||
|
||||
self.exo = {} # side -> pin.RobotWrapper
|
||||
self.q_exo = {} # side -> q
|
||||
self.ee_id_exo = {} # side -> frame id
|
||||
self.qmap = {} # side -> {joint_name: q_idx}
|
||||
self.ee_id_g1 = {} # side -> frame id
|
||||
|
||||
self._load_exo_models(assets_dir)
|
||||
for a in self.arms:
|
||||
self.ee_id_g1[a.side] = _frame_id(self.robot_g1.model, a.g1_ee)
|
||||
|
||||
self.viewer = None
|
||||
self.markers: Markers | None = None
|
||||
self.viz_g1 = None
|
||||
self.viz_exo = {} # side -> viz
|
||||
|
||||
def _frozen_joint_indices(self) -> dict[str, int]:
|
||||
out = {}
|
||||
m = self.robot_g1.model
|
||||
for name in self.frozen_joints:
|
||||
if name in m.names:
|
||||
jid = m.getJointId(name)
|
||||
out[name] = m.idx_qs[jid]
|
||||
logger.info(f"freezing joint: {name} (q_idx={out[name]})")
|
||||
return out
|
||||
|
||||
def _find_exo_ee(self, model, ee_name: str = "ee") -> int:
|
||||
ee = _frame_id(model, ee_name)
|
||||
if ee is not None:
|
||||
return ee
|
||||
for fid in reversed(range(model.nframes)):
|
||||
if model.frames[fid].type == self.pin.FrameType.BODY:
|
||||
return fid
|
||||
return 0
|
||||
|
||||
def _build_joint_map(self, robot) -> dict[str, int]:
|
||||
m = robot.model
|
||||
return {n: m.idx_qs[m.getJointId(n)] for n in JOINTS if n in m.names}
|
||||
|
||||
def _load_exo_models(self, assets_dir: str):
|
||||
pin = self.pin
|
||||
for a in self.arms:
|
||||
if not os.path.exists(a.urdf):
|
||||
logger.warning(f"{a.side} exo urdf not found: {a.urdf}")
|
||||
continue
|
||||
r = pin.RobotWrapper.BuildFromURDF(a.urdf, assets_dir)
|
||||
self.exo[a.side] = r
|
||||
self.q_exo[a.side] = pin.neutral(r.model)
|
||||
self.ee_id_exo[a.side] = self._find_exo_ee(r.model)
|
||||
self.qmap[a.side] = self._build_joint_map(r)
|
||||
logger.info(f"loaded {a.side} exo urdf: {a.urdf}")
|
||||
|
||||
def init_visualization(self):
|
||||
"""
|
||||
Creates a browser-based visualization of exoskeleton and G1 robot,
|
||||
highlighting end-effector frames and target positions.
|
||||
"""
|
||||
try:
|
||||
from pinocchio.visualize import MeshcatVisualizer
|
||||
except ImportError as e:
|
||||
logger.warning(f"meshcat viz unavailable: {e}")
|
||||
return
|
||||
|
||||
# g1
|
||||
self.viz_g1 = MeshcatVisualizer(
|
||||
self.robot_g1.model, self.robot_g1.collision_model, self.robot_g1.visual_model
|
||||
)
|
||||
self.viz_g1.initViewer(open=True)
|
||||
self.viz_g1.loadViewerModel("g1")
|
||||
self.viz_g1.display(self.q_g1)
|
||||
|
||||
self.viewer = self.viz_g1.viewer
|
||||
self.markers = Markers(self.viewer)
|
||||
|
||||
# exos
|
||||
for a in self.arms:
|
||||
if a.side not in self.exo:
|
||||
continue
|
||||
r = self.exo[a.side]
|
||||
v = MeshcatVisualizer(r.model, r.collision_model, r.visual_model)
|
||||
v.initViewer(open=False)
|
||||
v.viewer = self.viewer
|
||||
v.loadViewerModel(a.root)
|
||||
offset_tf = np.eye(4)
|
||||
offset_tf[:3, 3] = a.offset
|
||||
self.viewer[a.root].set_transform(offset_tf)
|
||||
v.display(self.q_exo[a.side])
|
||||
self.viz_exo[a.side] = v
|
||||
|
||||
# markers
|
||||
for a in self.arms:
|
||||
p = a.marker_prefix
|
||||
self.markers.sphere(f"markers/{p}_exo_ee", 0.012, (0.2, 1.0, 0.2, 0.9))
|
||||
self.markers.sphere(f"markers/{p}_g1_ee", 0.015, (1.0, 0.2, 0.2, 0.9))
|
||||
self.markers.sphere(f"markers/{p}_ik_target", 0.015, (0.1, 0.3, 1.0, 0.9))
|
||||
self.markers.axes(f"markers/{p}_exo_axes", 0.06)
|
||||
self.markers.axes(f"markers/{p}_g1_axes", 0.08)
|
||||
|
||||
logger.info(f"meshcat viz initialized: {self.viewer.url()}")
|
||||
print(f"\nmeshcat url: {self.viewer.url()}\n")
|
||||
|
||||
def _fk_target_world(self, side: str, angles: dict[str, float]) -> np.ndarray | None:
|
||||
"""returns wrist frame target to be used for G1 IK in 4x4 homogeneous transform. Takes offset into account."""
|
||||
if side not in self.exo or not angles:
|
||||
return None
|
||||
|
||||
pin = self.pin
|
||||
q = self.q_exo[side]
|
||||
qmap = self.qmap[side]
|
||||
|
||||
for name, ang in angles.items():
|
||||
idx = qmap.get(name)
|
||||
if idx is not None:
|
||||
q[idx] = float(ang)
|
||||
|
||||
r = self.exo[side]
|
||||
pin.forwardKinematics(r.model, r.data, q)
|
||||
pin.updateFramePlacements(r.model, r.data)
|
||||
|
||||
ee = r.data.oMf[self.ee_id_exo[side]]
|
||||
target = np.eye(4)
|
||||
target[:3, :3] = ee.rotation
|
||||
# offset gets applied in world space
|
||||
cfg = next(a for a in self.arms if a.side == side)
|
||||
target[:3, 3] = cfg.offset + ee.translation
|
||||
return target
|
||||
|
||||
def update_visualization(self):
|
||||
if self.viewer is None or self.markers is None:
|
||||
return
|
||||
|
||||
pin = self.pin
|
||||
|
||||
# g1
|
||||
if self.viz_g1 is not None:
|
||||
self.viz_g1.display(self.q_g1)
|
||||
pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1)
|
||||
pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data)
|
||||
|
||||
for a in self.arms:
|
||||
fid = self.ee_id_g1.get(a.side)
|
||||
if fid is None:
|
||||
continue
|
||||
ee_tf = self.robot_g1.data.oMf[fid].homogeneous
|
||||
p = a.marker_prefix
|
||||
self.markers.tf(f"markers/{p}_g1_ee", ee_tf)
|
||||
self.markers.tf(f"markers/{p}_g1_axes", ee_tf)
|
||||
|
||||
# exos
|
||||
for a in self.arms:
|
||||
side = a.side
|
||||
v = self.viz_exo.get(side)
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
v.display(self.q_exo[side])
|
||||
r = self.exo[side]
|
||||
pin.forwardKinematics(r.model, r.data, self.q_exo[side])
|
||||
pin.updateFramePlacements(r.model, r.data)
|
||||
|
||||
ee = r.data.oMf[self.ee_id_exo[side]]
|
||||
world_tf = (pin.SE3(np.eye(3), a.offset) * ee).homogeneous
|
||||
p = a.marker_prefix
|
||||
self.markers.tf(f"markers/{p}_exo_ee", world_tf)
|
||||
self.markers.tf(f"markers/{p}_exo_axes", world_tf)
|
||||
|
||||
target_tf = np.eye(4)
|
||||
target_tf[:3, :3] = ee.rotation
|
||||
target_tf[:3, 3] = a.offset + ee.translation
|
||||
self.markers.tf(f"markers/{p}_ik_target", target_tf)
|
||||
|
||||
def compute_g1_joints_from_exo(
|
||||
self,
|
||||
left_angles: dict[str, float],
|
||||
right_angles: dict[str, float],
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Performs FK on exoskeleton to get end-effector poses in world frame,
|
||||
after which it solves IK on G1 to return joint angles matching those poses in G1 motor order.
|
||||
"""
|
||||
pin = self.pin
|
||||
|
||||
targets = {
|
||||
"left": self._fk_target_world("left", left_angles),
|
||||
"right": self._fk_target_world("right", right_angles),
|
||||
}
|
||||
|
||||
# fallback to current g1 ee pose if missing target
|
||||
pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1)
|
||||
pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data)
|
||||
|
||||
for a in self.arms:
|
||||
if targets[a.side] is not None:
|
||||
continue
|
||||
fid = self.ee_id_g1.get(a.side)
|
||||
if fid is not None:
|
||||
targets[a.side] = self.robot_g1.data.oMf[fid].homogeneous
|
||||
|
||||
if targets["left"] is None or targets["right"] is None:
|
||||
logger.warning("missing ik targets, returning current pose")
|
||||
return {}
|
||||
|
||||
frozen_vals = {n: self.q_g1[i] for n, i in self.frozen_idx.items()}
|
||||
|
||||
self.q_g1, _ = self.g1_ik.solve_ik(
|
||||
targets["left"], targets["right"], current_lr_arm_motor_q=self.q_g1
|
||||
)
|
||||
|
||||
for n, i in self.frozen_idx.items():
|
||||
self.q_g1[i] = frozen_vals[n]
|
||||
|
||||
return {
|
||||
f"{j.name}.q": float(self.q_g1[i])
|
||||
for i, j in enumerate(G1_29_JointArmIndex)
|
||||
if i < len(self.q_g1)
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import serial
|
||||
|
||||
from .exo_calib import ExoskeletonCalibration, exo_raw_to_angles, run_exo_calibration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_raw16(line: bytes) -> list[int] | None:
|
||||
try:
|
||||
parts = line.decode("utf-8", errors="ignore").split()
|
||||
if len(parts) < 16:
|
||||
return None
|
||||
return [int(x) for x in parts[:16]]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def read_raw_from_serial(ser) -> list[int] | None:
|
||||
"""Read latest sample from serial; if buffer is backed up, keep only the newest."""
|
||||
last = None
|
||||
while ser.in_waiting > 0:
|
||||
b = ser.readline()
|
||||
if not b:
|
||||
break
|
||||
raw16 = parse_raw16(b)
|
||||
if raw16 is not None:
|
||||
last = raw16
|
||||
if last is None:
|
||||
b = ser.readline()
|
||||
if b:
|
||||
last = parse_raw16(b)
|
||||
return last
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExoskeletonArm:
|
||||
port: str
|
||||
calibration_fpath: Path
|
||||
side: str
|
||||
baud_rate: int = 115200
|
||||
|
||||
_ser: serial.Serial | None = None
|
||||
calibration: ExoskeletonCalibration | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.calibration_fpath.is_file():
|
||||
self._load_calibration()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._ser is not None and getattr(self._ser, "is_open", False)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.calibration is not None
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
return
|
||||
try:
|
||||
self._ser = serial.Serial(self.port, self.baud_rate, timeout=0.02)
|
||||
self._ser.reset_input_buffer()
|
||||
logger.info(f"connected: {self.port}")
|
||||
except serial.SerialException as e:
|
||||
raise ConnectionError(f"failed to connect to {self.port}: {e}") from e
|
||||
|
||||
if calibrate and not self.is_calibrated:
|
||||
self.calibrate()
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self._ser:
|
||||
try:
|
||||
self._ser.close()
|
||||
finally:
|
||||
self._ser = None
|
||||
|
||||
def _load_calibration(self) -> None:
|
||||
try:
|
||||
data = json.loads(self.calibration_fpath.read_text())
|
||||
self.calibration = ExoskeletonCalibration.from_dict(data)
|
||||
logger.info(f"loaded calibration: {self.calibration_fpath}")
|
||||
except Exception as e:
|
||||
logger.warning(f"failed to load calibration: {e}")
|
||||
|
||||
def read_raw(self) -> list[int] | None:
|
||||
if not self._ser:
|
||||
return None
|
||||
return read_raw_from_serial(self._ser)
|
||||
|
||||
def get_angles(self) -> dict[str, float]:
|
||||
if not self.calibration:
|
||||
raise RuntimeError("exoskeleton not calibrated")
|
||||
raw = self.read_raw()
|
||||
return {} if raw is None else exo_raw_to_angles(raw, self.calibration)
|
||||
|
||||
def calibrate(self) -> None:
|
||||
ser = self._ser
|
||||
self.calibration = run_exo_calibration(ser, self.side, self.calibration_fpath)
|
||||
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_unitree_g1 import UnitreeG1TeleoperatorConfig
|
||||
from .exo_ik import ExoskeletonIKHelper
|
||||
from .exo_serial import ExoskeletonArm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnitreeG1Teleoperator(Teleoperator):
|
||||
"""
|
||||
Bimanual exoskeleton arms teleoperator for Unitree G1 arms.
|
||||
|
||||
Uses inverse kinematics: exoskeleton FK computes end-effector pose,
|
||||
G1 IK solves for joint angles.
|
||||
"""
|
||||
|
||||
config_class = UnitreeG1TeleoperatorConfig
|
||||
name = "unitree_g1"
|
||||
|
||||
def __init__(self, config: UnitreeG1TeleoperatorConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Setup calibration directory
|
||||
self.calibration_dir = (
|
||||
config.calibration_dir
|
||||
if config.calibration_dir
|
||||
else HF_LEROBOT_CALIBRATION / TELEOPERATORS / self.name
|
||||
)
|
||||
self.calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
left_id = f"{config.id}_left" if config.id else "left"
|
||||
right_id = f"{config.id}_right" if config.id else "right"
|
||||
|
||||
# Create exoskeleton arm instances
|
||||
self.left_arm = ExoskeletonArm(
|
||||
port=config.left_arm_config.port,
|
||||
baud_rate=config.left_arm_config.baud_rate,
|
||||
calibration_fpath=self.calibration_dir / f"{left_id}.json",
|
||||
side="left",
|
||||
)
|
||||
self.right_arm = ExoskeletonArm(
|
||||
port=config.right_arm_config.port,
|
||||
baud_rate=config.right_arm_config.baud_rate,
|
||||
calibration_fpath=self.calibration_dir / f"{right_id}.json",
|
||||
side="right",
|
||||
)
|
||||
|
||||
self.ik_helper: ExoskeletonIKHelper | None = None
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{name}.q": float for name in self._g1_joint_names}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()]
|
||||
self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints)
|
||||
logger.info("IK helper initialized")
|
||||
|
||||
def calibrate(self) -> None:
|
||||
if not self.left_arm.is_calibrated:
|
||||
logger.info("Starting calibration for left arm...")
|
||||
self.left_arm.calibrate()
|
||||
else:
|
||||
logger.info("Left arm already calibrated. Skipping.")
|
||||
|
||||
if not self.right_arm.is_calibrated:
|
||||
logger.info("Starting calibration for right arm...")
|
||||
self.right_arm.calibrate()
|
||||
else:
|
||||
logger.info("Right arm already calibrated. Skipping.")
|
||||
|
||||
logger.info("Starting visualization to verify calibration...")
|
||||
self.run_visualization_loop()
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, float]:
|
||||
left_angles = self.left_arm.get_angles()
|
||||
right_angles = self.right_arm.get_angles()
|
||||
return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Exoskeleton arms do not support feedback")
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
def run_visualization_loop(self):
|
||||
"""Run interactive Meshcat visualization loop to verify tracking."""
|
||||
if self.ik_helper is None:
|
||||
frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()]
|
||||
self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints)
|
||||
|
||||
self.ik_helper.init_visualization()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Visualization running! Move the exoskeletons to test tracking.")
|
||||
print("Press Ctrl+C to exit.")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
while True:
|
||||
left_angles = self.left_arm.get_angles()
|
||||
right_angles = self.right_arm.get_angles()
|
||||
|
||||
self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles)
|
||||
self.ik_helper.update_visualization()
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nVisualization stopped.")
|
||||
|
||||
@cached_property
|
||||
def _g1_joint_names(self) -> list[str]:
|
||||
return [joint.name for joint in G1_29_JointIndex]
|
||||
@@ -13,12 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .teleoperator import Teleoperator
|
||||
|
||||
|
||||
class TeleopEvents(Enum):
|
||||
@@ -31,7 +33,7 @@ class TeleopEvents(Enum):
|
||||
TERMINATE_EPISODE = "terminate_episode"
|
||||
|
||||
|
||||
def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
if config.type == "keyboard":
|
||||
from .keyboard import KeyboardTeleop
|
||||
@@ -73,6 +75,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from .homunculus import HomunculusArm
|
||||
|
||||
return HomunculusArm(config)
|
||||
elif config.type == "unitree_g1":
|
||||
from .unitree_g1 import UnitreeG1Teleoperator
|
||||
|
||||
return UnitreeG1Teleoperator(config)
|
||||
elif config.type == "bi_so_leader":
|
||||
from .bi_so_leader import BiSOLeader
|
||||
|
||||
@@ -81,8 +87,16 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from .reachy2_teleoperator import Reachy2Teleoperator
|
||||
|
||||
return Reachy2Teleoperator(config)
|
||||
elif config.type == "openarm_leader":
|
||||
from .openarm_leader import OpenArmLeader
|
||||
|
||||
return OpenArmLeader(config)
|
||||
elif config.type == "bi_openarm_leader":
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
|
||||
return BiOpenArmLeader(config)
|
||||
else:
|
||||
try:
|
||||
return cast(Teleoperator, make_device_from_device_class(config))
|
||||
return cast("Teleoperator", make_device_from_device_class(config))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error creating robot with config {config}: {e}") from e
|
||||
|
||||
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
|
||||
OBS_LANGUAGE = OBS_STR + ".language"
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask"
|
||||
|
||||
ACTION = "action"
|
||||
ACTION_PREFIX = ACTION + "."
|
||||
|
||||
+127
-57
@@ -20,7 +20,9 @@
|
||||
# ```
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
@@ -28,6 +30,50 @@ from lerobot.cameras.configs import Cv2Rotation
|
||||
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
RealVideoCapture = cv2.VideoCapture
|
||||
|
||||
|
||||
class MockLoopingVideoCapture:
|
||||
"""
|
||||
Wraps the real OpenCV VideoCapture.
|
||||
Motivation: cv2.VideoCapture(file.png) is only valid for one read.
|
||||
Strategy: Read the file once & return the cached frame for subsequent reads.
|
||||
Consequence: No recurrent I/O operations, but we keep the test artifacts simple.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
args_clean = [str(a) if isinstance(a, Path) else a for a in args]
|
||||
self._real_vc = RealVideoCapture(*args_clean, **kwargs)
|
||||
self._cached_frame = None
|
||||
|
||||
def read(self):
|
||||
ret, frame = self._real_vc.read()
|
||||
|
||||
if ret:
|
||||
self._cached_frame = frame
|
||||
return ret, frame
|
||||
|
||||
if not ret and self._cached_frame is not None:
|
||||
return True, self._cached_frame.copy()
|
||||
|
||||
return ret, frame
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._real_vc, name)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_opencv_videocapture():
|
||||
"""
|
||||
Automatically patches cv2.VideoCapture for all tests.
|
||||
"""
|
||||
module_path = OpenCVCamera.__module__
|
||||
target = f"{module_path}.cv2.VideoCapture"
|
||||
|
||||
with patch(target, new=MockLoopingVideoCapture):
|
||||
yield
|
||||
|
||||
|
||||
# NOTE(Steven): more tests + assertions?
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
|
||||
DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png"
|
||||
@@ -43,25 +89,22 @@ def test_abc_implementation():
|
||||
|
||||
|
||||
def test_connect():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
|
||||
|
||||
camera.connect(warmup=False)
|
||||
|
||||
assert camera.is_connected
|
||||
with OpenCVCamera(config) as camera:
|
||||
assert camera.is_connected
|
||||
|
||||
|
||||
def test_connect_already_connected():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
|
||||
|
||||
with pytest.raises(DeviceAlreadyConnectedError):
|
||||
camera.connect(warmup=False)
|
||||
with OpenCVCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError):
|
||||
camera.connect()
|
||||
|
||||
|
||||
def test_connect_invalid_camera_path():
|
||||
config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png")
|
||||
|
||||
camera = OpenCVCamera(config)
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
@@ -74,27 +117,25 @@ def test_invalid_width_connect():
|
||||
width=99999, # Invalid width to trigger error
|
||||
height=480,
|
||||
)
|
||||
camera = OpenCVCamera(config)
|
||||
|
||||
camera = OpenCVCamera(config)
|
||||
with pytest.raises(RuntimeError):
|
||||
camera.connect(warmup=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
|
||||
def test_read(index_or_path):
|
||||
config = OpenCVCameraConfig(index_or_path=index_or_path)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0)
|
||||
|
||||
img = camera.read()
|
||||
|
||||
assert isinstance(img, np.ndarray)
|
||||
with OpenCVCamera(config) as camera:
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
|
||||
def test_read_before_connect():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
|
||||
camera = OpenCVCamera(config)
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.read()
|
||||
|
||||
@@ -119,32 +160,22 @@ def test_disconnect_before_connect():
|
||||
|
||||
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
|
||||
def test_async_read(index_or_path):
|
||||
config = OpenCVCameraConfig(index_or_path=index_or_path)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0)
|
||||
|
||||
try:
|
||||
with OpenCVCamera(config) as camera:
|
||||
img = camera.async_read()
|
||||
|
||||
assert camera.thread is not None
|
||||
assert camera.thread.is_alive()
|
||||
assert isinstance(img, np.ndarray)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends
|
||||
|
||||
|
||||
def test_async_read_timeout():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
|
||||
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
camera.async_read(timeout_ms=0)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
with OpenCVCamera(config) as camera, pytest.raises(TimeoutError):
|
||||
camera.async_read(timeout_ms=0) # consumes any available frame by then
|
||||
camera.async_read(timeout_ms=0) # request immediately another one
|
||||
|
||||
|
||||
def test_async_read_before_connect():
|
||||
@@ -155,6 +186,50 @@ def test_async_read_before_connect():
|
||||
_ = camera.async_read()
|
||||
|
||||
|
||||
def test_read_latest():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
|
||||
|
||||
with OpenCVCamera(config) as camera:
|
||||
# ensure at least one fresh frame is captured
|
||||
frame = camera.read()
|
||||
latest = camera.read_latest()
|
||||
|
||||
assert isinstance(latest, np.ndarray)
|
||||
assert latest.shape == frame.shape
|
||||
|
||||
|
||||
def test_read_latest_before_connect():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH)
|
||||
|
||||
camera = OpenCVCamera(config)
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.read_latest()
|
||||
|
||||
|
||||
def test_read_latest_high_frequency():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
|
||||
|
||||
with OpenCVCamera(config) as camera:
|
||||
# prime to ensure frames are available
|
||||
ref = camera.read()
|
||||
|
||||
for _ in range(20):
|
||||
latest = camera.read_latest()
|
||||
assert isinstance(latest, np.ndarray)
|
||||
assert latest.shape == ref.shape
|
||||
|
||||
|
||||
def test_read_latest_too_old():
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0)
|
||||
|
||||
with OpenCVCamera(config) as camera:
|
||||
# prime to ensure frames are available
|
||||
_ = camera.read()
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
_ = camera.read_latest(max_age_ms=0) # immediately too old
|
||||
|
||||
|
||||
def test_fourcc_configuration():
|
||||
"""Test FourCC configuration validation and application."""
|
||||
|
||||
@@ -181,18 +256,15 @@ def test_fourcc_configuration():
|
||||
|
||||
def test_fourcc_with_camera():
|
||||
"""Test FourCC functionality with actual camera connection."""
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG")
|
||||
camera = OpenCVCamera(config)
|
||||
config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG", warmup_s=0)
|
||||
|
||||
# Connect should work with MJPG specified
|
||||
camera.connect(warmup=False)
|
||||
assert camera.is_connected
|
||||
with OpenCVCamera(config) as camera:
|
||||
assert camera.is_connected
|
||||
|
||||
# Read should work normally
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
camera.disconnect()
|
||||
# Read should work normally
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES)
|
||||
@@ -211,18 +283,16 @@ def test_rotation(rotation, index_or_path):
|
||||
dimensions = filename.split("_")[-1].split(".")[0] # Assumes filenames format (_wxh.png)
|
||||
original_width, original_height = map(int, dimensions.split("x"))
|
||||
|
||||
config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation, warmup_s=0)
|
||||
with OpenCVCamera(config) as camera:
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
|
||||
assert camera.width == original_height
|
||||
assert camera.height == original_width
|
||||
assert img.shape[:2] == (original_width, original_height)
|
||||
else:
|
||||
assert camera.width == original_width
|
||||
assert camera.height == original_height
|
||||
assert img.shape[:2] == (original_height, original_width)
|
||||
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
|
||||
assert camera.width == original_height
|
||||
assert camera.height == original_width
|
||||
assert img.shape[:2] == (original_width, original_height)
|
||||
else:
|
||||
assert camera.width == original_width
|
||||
assert camera.height == original_height
|
||||
assert img.shape[:2] == (original_height, original_width)
|
||||
|
||||
@@ -150,6 +150,44 @@ def test_async_read_before_connect(camera):
|
||||
_ = camera.async_read()
|
||||
|
||||
|
||||
def test_read_latest(camera):
|
||||
camera.connect()
|
||||
|
||||
frame = camera.read()
|
||||
latest = camera.read_latest()
|
||||
|
||||
assert isinstance(latest, np.ndarray)
|
||||
assert latest.shape == frame.shape
|
||||
|
||||
|
||||
def test_read_latest_before_connect(camera):
|
||||
# camera fixture yields an unconnected camera instance
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.read_latest()
|
||||
|
||||
|
||||
def test_read_latest_high_frequency(camera):
|
||||
camera.connect()
|
||||
|
||||
# prime to ensure frames are available
|
||||
ref = camera.read()
|
||||
|
||||
for _ in range(20):
|
||||
latest = camera.read_latest()
|
||||
assert isinstance(latest, np.ndarray)
|
||||
assert latest.shape == ref.shape
|
||||
|
||||
|
||||
def test_read_latest_too_old(camera):
|
||||
camera.connect()
|
||||
|
||||
# prime to ensure frames are available
|
||||
_ = camera.read()
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
_ = camera.read_latest(max_age_ms=0) # immediately too old
|
||||
|
||||
|
||||
def test_wrong_camera_name():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="wrong-name", image_type="left")
|
||||
|
||||
@@ -62,19 +62,15 @@ def test_abc_implementation():
|
||||
|
||||
|
||||
def test_connect():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
|
||||
|
||||
camera.connect(warmup=False)
|
||||
assert camera.is_connected
|
||||
with RealSenseCamera(config) as camera:
|
||||
assert camera.is_connected
|
||||
|
||||
|
||||
def test_connect_already_connected():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
with pytest.raises(DeviceAlreadyConnectedError):
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
|
||||
with RealSenseCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError):
|
||||
camera.connect(warmup=False)
|
||||
|
||||
|
||||
@@ -96,12 +92,10 @@ def test_invalid_width_connect():
|
||||
|
||||
|
||||
def test_read():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
|
||||
with RealSenseCamera(config) as camera:
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
|
||||
# TODO(Steven): Fix this test for the latest version of pyrealsense2.
|
||||
@@ -142,32 +136,21 @@ def test_disconnect_before_connect():
|
||||
|
||||
|
||||
def test_async_read():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
|
||||
|
||||
try:
|
||||
with RealSenseCamera(config) as camera:
|
||||
img = camera.async_read()
|
||||
|
||||
assert camera.thread is not None
|
||||
assert camera.thread.is_alive()
|
||||
assert isinstance(img, np.ndarray)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends
|
||||
|
||||
|
||||
def test_async_read_timeout():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
camera.async_read(timeout_ms=0)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
|
||||
with RealSenseCamera(config) as camera, pytest.raises(TimeoutError):
|
||||
camera.async_read(timeout_ms=0) # consumes any available frame by then
|
||||
camera.async_read(timeout_ms=0) # request immediately another one
|
||||
|
||||
|
||||
def test_async_read_before_connect():
|
||||
@@ -178,6 +161,47 @@ def test_async_read_before_connect():
|
||||
_ = camera.async_read()
|
||||
|
||||
|
||||
def test_read_latest():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
|
||||
with RealSenseCamera(config) as camera:
|
||||
img = camera.read()
|
||||
latest = camera.read_latest()
|
||||
|
||||
assert isinstance(latest, np.ndarray)
|
||||
assert latest.shape == img.shape
|
||||
|
||||
|
||||
def test_read_latest_high_frequency():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0)
|
||||
with RealSenseCamera(config) as camera:
|
||||
# prime with one read to ensure frames are available
|
||||
ref = camera.read()
|
||||
|
||||
for _ in range(20):
|
||||
latest = camera.read_latest()
|
||||
assert isinstance(latest, np.ndarray)
|
||||
assert latest.shape == ref.shape
|
||||
|
||||
|
||||
def test_read_latest_before_connect():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
camera = RealSenseCamera(config)
|
||||
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.read_latest()
|
||||
|
||||
|
||||
def test_read_latest_too_old():
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042")
|
||||
|
||||
with RealSenseCamera(config) as camera:
|
||||
# prime to ensure frames are available
|
||||
_ = camera.read()
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
_ = camera.read_latest(max_age_ms=0) # immediately too old
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rotation",
|
||||
[
|
||||
@@ -189,18 +213,16 @@ def test_async_read_before_connect():
|
||||
ids=["no_rot", "rot90", "rot180", "rot270"],
|
||||
)
|
||||
def test_rotation(rotation):
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect(warmup=False)
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation, warmup_s=0)
|
||||
with RealSenseCamera(config) as camera:
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
img = camera.read()
|
||||
assert isinstance(img, np.ndarray)
|
||||
|
||||
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
|
||||
assert camera.width == 480
|
||||
assert camera.height == 640
|
||||
assert img.shape[:2] == (640, 480)
|
||||
else:
|
||||
assert camera.width == 640
|
||||
assert camera.height == 480
|
||||
assert img.shape[:2] == (480, 640)
|
||||
if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270):
|
||||
assert camera.width == 480
|
||||
assert camera.height == 640
|
||||
assert img.shape[:2] == (640, 480)
|
||||
else:
|
||||
assert camera.width == 640
|
||||
assert camera.height == 480
|
||||
assert img.shape[:2] == (480, 640)
|
||||
|
||||
@@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
assert img.shape[0] == 3, f"Image {image_key} should have 3 channels"
|
||||
|
||||
assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory):
|
||||
"""Regression test for aggregating a dataset that is itself a result of a previous merge.
|
||||
|
||||
This test reproduces the bug where merging datasets with multiple parquet files
|
||||
(e.g., from a previous merge with file rotation) would cause FileNotFoundError
|
||||
because metadata file indices were incorrectly preserved instead of being mapped
|
||||
to their actual destination files.
|
||||
|
||||
The fix adds src_to_dst tracking in aggregate_data() to correctly map source
|
||||
file indices to destination file indices.
|
||||
"""
|
||||
# Step 1: Create datasets A and B
|
||||
ds_a = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds_a",
|
||||
repo_id=f"{DUMMY_REPO_ID}_a",
|
||||
total_episodes=4,
|
||||
total_frames=200,
|
||||
)
|
||||
ds_b = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds_b",
|
||||
repo_id=f"{DUMMY_REPO_ID}_b",
|
||||
total_episodes=4,
|
||||
total_frames=200,
|
||||
)
|
||||
|
||||
# Step 2: Merge A+B into AB with small file size to force multiple files
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_a.repo_id, ds_b.repo_id],
|
||||
roots=[ds_a.root, ds_b.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_ab",
|
||||
aggr_root=tmp_path / "ds_ab",
|
||||
data_files_size_in_mb=0.01, # Force file rotation
|
||||
)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "ds_ab")
|
||||
ds_ab = LeRobotDataset(f"{DUMMY_REPO_ID}_ab", root=tmp_path / "ds_ab")
|
||||
|
||||
# Verify AB has multiple data files (file rotation occurred)
|
||||
ab_data_files = list((tmp_path / "ds_ab" / "data").rglob("*.parquet"))
|
||||
assert len(ab_data_files) > 1, "First merge should create multiple parquet files"
|
||||
|
||||
# Step 3: Create dataset C
|
||||
ds_c = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds_c",
|
||||
repo_id=f"{DUMMY_REPO_ID}_c",
|
||||
total_episodes=2,
|
||||
total_frames=100,
|
||||
)
|
||||
|
||||
# Step 4: Merge AB+C into final - THIS IS WHERE THE BUG OCCURRED
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_ab.repo_id, ds_c.repo_id],
|
||||
roots=[ds_ab.root, ds_c.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_abc",
|
||||
aggr_root=tmp_path / "ds_abc",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "ds_abc")
|
||||
ds_abc = LeRobotDataset(f"{DUMMY_REPO_ID}_abc", root=tmp_path / "ds_abc")
|
||||
|
||||
# Step 5: Verify all data files referenced in metadata actually exist
|
||||
for ep_idx in range(ds_abc.num_episodes):
|
||||
data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx)
|
||||
assert data_file_path.exists(), (
|
||||
f"Episode {ep_idx} references non-existent file: {data_file_path}\n"
|
||||
"This indicates the src_to_dst mapping fix is not working correctly."
|
||||
)
|
||||
|
||||
# Step 6: Verify we can iterate through the entire dataset without FileNotFoundError
|
||||
expected_episodes = ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes
|
||||
expected_frames = ds_a.num_frames + ds_b.num_frames + ds_c.num_frames
|
||||
|
||||
assert ds_abc.num_episodes == expected_episodes
|
||||
assert ds_abc.num_frames == expected_frames
|
||||
|
||||
# This would raise FileNotFoundError before the fix
|
||||
assert_dataset_iteration_works(ds_abc)
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_features,
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
@@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
assert "reward" in modified_dataset.meta.features
|
||||
|
||||
|
||||
def test_modify_tasks_single_task_for_all(sample_dataset):
|
||||
"""Test setting a single task for all episodes."""
|
||||
new_task = "Pick up the cube and place it"
|
||||
|
||||
modified_dataset = modify_tasks(sample_dataset, new_task=new_task)
|
||||
|
||||
# Verify all episodes have the new task
|
||||
assert len(modified_dataset.meta.tasks) == 1
|
||||
assert new_task in modified_dataset.meta.tasks.index
|
||||
|
||||
# Verify task_index is 0 for all frames (only one task)
|
||||
for i in range(len(modified_dataset)):
|
||||
item = modified_dataset[i]
|
||||
assert item["task_index"].item() == 0
|
||||
assert item["task"] == new_task
|
||||
|
||||
|
||||
def test_modify_tasks_episode_specific(sample_dataset):
|
||||
"""Test setting different tasks for specific episodes."""
|
||||
episode_tasks = {
|
||||
0: "Task A",
|
||||
1: "Task B",
|
||||
2: "Task A",
|
||||
3: "Task C",
|
||||
4: "Task B",
|
||||
}
|
||||
|
||||
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
|
||||
|
||||
# Verify correct number of unique tasks
|
||||
unique_tasks = set(episode_tasks.values())
|
||||
assert len(modified_dataset.meta.tasks) == len(unique_tasks)
|
||||
|
||||
# Verify each episode has the correct task
|
||||
for ep_idx, expected_task in episode_tasks.items():
|
||||
ep_data = modified_dataset.meta.episodes[ep_idx]
|
||||
assert ep_data["tasks"][0] == expected_task
|
||||
|
||||
|
||||
def test_modify_tasks_default_with_overrides(sample_dataset):
|
||||
"""Test setting a default task with specific overrides."""
|
||||
default_task = "Default task"
|
||||
override_task = "Special task"
|
||||
episode_tasks = {2: override_task, 4: override_task}
|
||||
|
||||
modified_dataset = modify_tasks(
|
||||
sample_dataset,
|
||||
new_task=default_task,
|
||||
episode_tasks=episode_tasks,
|
||||
)
|
||||
|
||||
# Verify correct number of unique tasks
|
||||
assert len(modified_dataset.meta.tasks) == 2
|
||||
assert default_task in modified_dataset.meta.tasks.index
|
||||
assert override_task in modified_dataset.meta.tasks.index
|
||||
|
||||
# Verify episodes have correct tasks
|
||||
for ep_idx in range(5):
|
||||
ep_data = modified_dataset.meta.episodes[ep_idx]
|
||||
if ep_idx in episode_tasks:
|
||||
assert ep_data["tasks"][0] == override_task
|
||||
else:
|
||||
assert ep_data["tasks"][0] == default_task
|
||||
|
||||
|
||||
def test_modify_tasks_no_task_specified(sample_dataset):
|
||||
"""Test error when no task is specified."""
|
||||
with pytest.raises(ValueError, match="Must specify at least one of new_task or episode_tasks"):
|
||||
modify_tasks(sample_dataset)
|
||||
|
||||
|
||||
def test_modify_tasks_invalid_episode_indices(sample_dataset):
|
||||
"""Test error with invalid episode indices."""
|
||||
with pytest.raises(ValueError, match="Invalid episode indices"):
|
||||
modify_tasks(sample_dataset, episode_tasks={10: "Task", 20: "Task"})
|
||||
|
||||
|
||||
def test_modify_tasks_updates_info_json(sample_dataset):
|
||||
"""Test that total_tasks is updated in info.json."""
|
||||
episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"}
|
||||
|
||||
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
|
||||
|
||||
# Verify total_tasks is updated
|
||||
assert modified_dataset.meta.total_tasks == 3
|
||||
|
||||
|
||||
def test_modify_tasks_preserves_other_metadata(sample_dataset):
|
||||
"""Test that modifying tasks preserves other metadata."""
|
||||
original_frames = sample_dataset.meta.total_frames
|
||||
original_episodes = sample_dataset.meta.total_episodes
|
||||
original_fps = sample_dataset.meta.fps
|
||||
|
||||
modified_dataset = modify_tasks(sample_dataset, new_task="New task")
|
||||
|
||||
# Verify other metadata is preserved
|
||||
assert modified_dataset.meta.total_frames == original_frames
|
||||
assert modified_dataset.meta.total_episodes == original_episodes
|
||||
assert modified_dataset.meta.fps == original_fps
|
||||
|
||||
|
||||
def test_modify_tasks_task_index_correct(sample_dataset):
|
||||
"""Test that task_index values are correct in data files."""
|
||||
# Create tasks that will have predictable indices (sorted alphabetically)
|
||||
episode_tasks = {
|
||||
0: "Alpha task", # Will be index 0
|
||||
1: "Beta task", # Will be index 1
|
||||
2: "Alpha task", # Will be index 0
|
||||
3: "Gamma task", # Will be index 2
|
||||
4: "Beta task", # Will be index 1
|
||||
}
|
||||
|
||||
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
|
||||
|
||||
# Verify task indices are correct
|
||||
task_to_expected_idx = {
|
||||
"Alpha task": 0,
|
||||
"Beta task": 1,
|
||||
"Gamma task": 2,
|
||||
}
|
||||
|
||||
for i in range(len(modified_dataset)):
|
||||
item = modified_dataset[i]
|
||||
ep_idx = item["episode_index"].item()
|
||||
expected_task = episode_tasks[ep_idx]
|
||||
expected_idx = task_to_expected_idx[expected_task]
|
||||
assert item["task_index"].item() == expected_idx
|
||||
assert item["task"] == expected_task
|
||||
|
||||
|
||||
def test_modify_tasks_in_place(sample_dataset):
|
||||
"""Test that modify_tasks modifies the dataset in-place."""
|
||||
original_root = sample_dataset.root
|
||||
|
||||
modified_dataset = modify_tasks(sample_dataset, new_task="New task")
|
||||
|
||||
# Verify same instance is returned and root is unchanged
|
||||
assert modified_dataset is sample_dataset
|
||||
assert modified_dataset.root == original_root
|
||||
|
||||
|
||||
def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset):
|
||||
"""Test that original tasks are kept when using episode_tasks without new_task."""
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
|
||||
# Ensure episodes metadata is loaded
|
||||
if sample_dataset.meta.episodes is None:
|
||||
sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root)
|
||||
|
||||
# Get original tasks for episodes not being overridden
|
||||
original_task_ep0 = sample_dataset.meta.episodes[0]["tasks"][0]
|
||||
original_task_ep1 = sample_dataset.meta.episodes[1]["tasks"][0]
|
||||
|
||||
# Only override episodes 2, 3, 4
|
||||
episode_tasks = {2: "New Task A", 3: "New Task B", 4: "New Task A"}
|
||||
|
||||
modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks)
|
||||
|
||||
# Verify original tasks are kept for episodes 0 and 1
|
||||
assert modified_dataset.meta.episodes[0]["tasks"][0] == original_task_ep0
|
||||
assert modified_dataset.meta.episodes[1]["tasks"][0] == original_task_ep1
|
||||
|
||||
# Verify new tasks for overridden episodes
|
||||
assert modified_dataset.meta.episodes[2]["tasks"][0] == "New Task A"
|
||||
assert modified_dataset.meta.episodes[3]["tasks"][0] == "New Task B"
|
||||
assert modified_dataset.meta.episodes[4]["tasks"][0] == "New Task A"
|
||||
|
||||
|
||||
def test_convert_image_to_video_dataset(tmp_path):
|
||||
"""Test converting lerobot/pusht_image dataset to video format."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
@@ -390,6 +390,30 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
SharpnessJitter((2.0, 0.1))
|
||||
|
||||
|
||||
def test_make_transform_from_config_with_v2_resize(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
assert isinstance(tf, v2.Resize)
|
||||
output = tf(img_tensor)
|
||||
assert output.shape[-2:] == (32, 32)
|
||||
|
||||
|
||||
def test_make_transform_from_config_with_v2_identity(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformConfig(type="Identity", kwargs={})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
assert isinstance(tf, v2.Identity)
|
||||
output = tf(img_tensor)
|
||||
assert output.shape == img_tensor.shape
|
||||
|
||||
|
||||
def test_make_transform_from_config_invalid_type():
|
||||
tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={})
|
||||
with pytest.raises(ValueError, match="not valid"):
|
||||
make_transform_from_config(tf_cfg)
|
||||
|
||||
|
||||
def test_save_all_transforms(img_tensor_factory, tmp_path):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(enable=True)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user