mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 32fc5504cc | |||
| fc8a388a25 | |||
| 3c84d271d5 | |||
| 1ba3975020 | |||
| 35363c5798 | |||
| 778db19a17 | |||
| d2d01399d6 | |||
| 5eba4ce6f4 | |||
| cca0296cd6 | |||
| 489cb7b6b9 | |||
| e14bdf57d0 | |||
| 97e7e0f9ed | |||
| 0f39248445 | |||
| a6370dd783 | |||
| 14a15f90e7 | |||
| 9c24a09665 |
@@ -101,9 +101,11 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: |
|
||||
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
|
||||
github.event_name == 'push' ||
|
||||
github.event_name == 'workflow_dispatch'
|
||||
github.repository == 'huggingface/lerobot' && (
|
||||
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
|
||||
github.event_name == 'push' ||
|
||||
github.event_name == 'workflow_dispatch'
|
||||
)
|
||||
outputs:
|
||||
image_tag: ${{ steps.set_tag.outputs.image_tag }}
|
||||
env:
|
||||
|
||||
@@ -91,6 +91,7 @@ jobs:
|
||||
name: Build and Push Docker
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: github.repository == 'huggingface/lerobot'
|
||||
outputs:
|
||||
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
|
||||
env:
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
# Action tokenizer benchmark
|
||||
|
||||
## Questions
|
||||
|
||||
What is the trade-off between:
|
||||
|
||||
- **Compression**: how many tokens are needed to represent an action chunk (e.g. horizon × action_dim floats)?
|
||||
- **Reconstruction quality**: how well does encode-then-decode preserve the original actions?
|
||||
- **Speed**: how long does encoding and decoding take per chunk?
|
||||
|
||||
How to choose an action tokenizer?
|
||||
|
||||
- Which tokenizer architecture (e.g. dct + BPE, DCT + BPE)?
|
||||
- Which **action horizon** and **encoded dimensions** to use?
|
||||
- Which **normalization** (QUANTILES, MEAN_STD, MIN_MAX) and **delta transform** (relative vs absolute actions)?
|
||||
- How do reconstruction error and compression ratio vary across datasets and tokenizer settings?
|
||||
|
||||
This benchmark loads action chunks from a LeRobot dataset using the same pipeline as `lerobot-train-tokenizer`, runs a trained action tokenizer in encode/decode mode, and reports reconstruction error, compression stats, and timing. Results are saved as JSON under `outputs/` for comparison and analysis.
|
||||
|
||||
## Variables
|
||||
|
||||
**Dataset & chunking**
|
||||
|
||||
- **repo_id**: LeRobot dataset (e.g. `lerobot/pusht`). Action statistics and normalization are taken from the dataset metadata when available.
|
||||
- **action_horizon**: Number of future steps per action chunk (must match the tokenizer’s training).
|
||||
- **encoded_dims**: Dimension ranges to encode (e.g. `0:6` or `0:6,7:14`). Must match the tokenizer.
|
||||
- **max_episodes**: Cap on episodes to load (default: all).
|
||||
- **sample_fraction**: Fraction of chunks to sample per episode (default `0.2`) to keep runtime manageable.
|
||||
|
||||
**Transform & normalization**
|
||||
|
||||
- **normalization_mode**: `IDENTITY`, `MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`. Should match the tokenizer’s training.
|
||||
- **delta_dims**: Comma-separated dimension indices for delta (relative) transform.
|
||||
- **use_delta_transform**: Whether to convert actions to relative to current state for those dimensions.
|
||||
- **state_key**: Dataset key for state (e.g. `observation.state`) used when applying delta transform.
|
||||
|
||||
**Tokenizer & evaluation**
|
||||
|
||||
- **action_tokenizer_path**: Path or HuggingFace repo id of the trained tokenizer (e.g. `outputs/wavetoken`).
|
||||
- **max_chunks_for_reconstruction**: Max number of chunks to use for reconstruction and timing (default `500`) to limit runtime.
|
||||
|
||||
### Main parameters
|
||||
|
||||
| parameter | default | description |
|
||||
| -------------------------------- | ---------------------------- | ------------------------------------------------ |
|
||||
| **action_tokenizer_path** | (required) | Path or Hub id of the trained action tokenizer. |
|
||||
| **repo_id** | (required) | LeRobot dataset repo id. |
|
||||
| **action_horizon** | `10` | Future steps per chunk. |
|
||||
| **encoded_dims** | `0:6` | Dimension ranges to encode (e.g. `0:6,7:14`). |
|
||||
| **normalization_mode** | `QUANTILES` | Normalization mode for actions. |
|
||||
| **max_episodes** | all | Max episodes to load. |
|
||||
| **sample_fraction** | `0.2` | Fraction of chunks sampled per episode. |
|
||||
| **max_chunks_for_reconstruction**| `500` | Chunks used for reconstruction and timing. |
|
||||
| **output_dir** | `outputs/action_tokenizer_benchmark` | Directory for results JSON. |
|
||||
|
||||
## Metrics
|
||||
|
||||
**Reconstruction (lower is better)**
|
||||
|
||||
- **reconstruction_mae**: Mean absolute error between original and decoded action chunks.
|
||||
- **reconstruction_mse**: Mean squared error.
|
||||
- **reconstruction_rmse**: Root mean squared error.
|
||||
- **reconstruction_max_abs_error**: Maximum absolute error over all dimensions and samples.
|
||||
- **per_dimension_mae**: MAE per action dimension (list of length `action_dim`).
|
||||
|
||||
**Compression**
|
||||
|
||||
- **compression_ratio**: Ratio (action_horizon × action_dim) / mean number of tokens. Higher means more compression.
|
||||
- **mean_token_length**, **std_token_length**: Mean and standard deviation of token count per chunk.
|
||||
- **min_token_length**, **max_token_length**: Min and max token count.
|
||||
- **p50_token_length**, **p99_token_length**: 50th and 99th percentile token counts.
|
||||
|
||||
**Timing (seconds per chunk)**
|
||||
|
||||
- **mean_encode_time_sec**: Mean time to encode one chunk.
|
||||
- **mean_decode_time_sec**: Mean time to decode one chunk.
|
||||
|
||||
The JSON output also includes **num_chunks_evaluated** and **total_chunks_available** for context.
|
||||
|
||||
## How the benchmark works
|
||||
|
||||
1. **Load dataset**: LeRobot dataset is loaded for the given `repo_id` and `root`.
|
||||
2. **Build action chunks**: For each episode (up to `max_episodes`), action chunks are built with the same logic as `lerobot-train-tokenizer`: sliding window of length `action_horizon`, optional delta transform, and per-episode sampling with `sample_fraction`.
|
||||
3. **Extract and normalize**: Only `encoded_dims` are kept. Normalization is applied using the dataset’s action stats when available, according to `normalization_mode`.
|
||||
4. **Encode / decode**: A random sample of chunks (size `max_chunks_for_reconstruction`) is encoded and then decoded with the tokenizer. Encode and decode times are recorded per chunk.
|
||||
5. **Compute metrics**: Reconstruction metrics are computed between original and decoded chunks; compression and timing stats are aggregated.
|
||||
6. **Save results**: A JSON file is written to `output_dir` with name `{timestamp}_{repo_id}_action_tokenizer_results.json`, containing the full config and all metrics.
|
||||
|
||||
The pipeline (chunking, dimensions, normalization, delta) must match how the tokenizer was trained; otherwise reconstruction error can be large or the tokenizer may raise.
|
||||
|
||||
## Caveats
|
||||
|
||||
- The tokenizer’s **action_horizon** and **action_dim** (and optionally DCT settings) are fixed at training time. The benchmark infers dimensions from the dataset and encoded dims; the tokenizer path must correspond to a model trained with the same horizon and encoded dimensions.
|
||||
- Reconstruction is evaluated in **normalized space** (the same space the tokenizer sees). For interpretation in raw action space, you would need to invert normalization outside this script.
|
||||
- Only one tokenizer and one dataset are evaluated per run. To compare tokenizers or datasets, run the script multiple times and compare the saved JSON files.
|
||||
|
||||
## Example
|
||||
|
||||
Quick run with a local tokenizer and a small number of episodes:
|
||||
|
||||
```bash
|
||||
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
||||
--action-tokenizer-path=outputs/wavetoken \
|
||||
--repo-id=lerobot/pusht \
|
||||
--action-horizon=10 \
|
||||
--max-episodes=50 \
|
||||
--output-dir=outputs/action_tokenizer_benchmark
|
||||
```
|
||||
|
||||
With delta transform and custom encoded dimensions:
|
||||
|
||||
```bash
|
||||
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
||||
--action-tokenizer-path=outputs/wavetoken \
|
||||
--repo-id=lerobot/pusht \
|
||||
--action-horizon=10 \
|
||||
--encoded-dims=0:6,7:14 \
|
||||
--delta-dims=0,1,2,3,4,5 \
|
||||
--use-delta-transform \
|
||||
--normalization-mode=QUANTILES \
|
||||
--max-chunks-for-reconstruction=500 \
|
||||
--output-dir=outputs/action_tokenizer_benchmark
|
||||
```
|
||||
|
||||
Results are written to e.g. `outputs/action_tokenizer_benchmark/2026-02-12_14-30-00_lerobot_pusht_action_tokenizer_results.json`.
|
||||
|
||||
## Results
|
||||
|
||||
Results are stored as JSON in the directory given by `--output-dir` (default: `outputs/action_tokenizer_benchmark`). Each file contains:
|
||||
|
||||
- **config**: All script arguments (tokenizer path, repo_id, action_horizon, encoded_dims, normalization_mode, etc.) for reproducibility.
|
||||
- **metrics**: All reconstruction, compression, and timing metrics described above.
|
||||
|
||||
To compare runs, load and diff or aggregate these JSON files with your own scripts or notebooks.
|
||||
@@ -0,0 +1,442 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Benchmark action tokenization: reconstruction error, compression ratio, and timing.
|
||||
|
||||
Loads action chunks from a LeRobot dataset, encodes/decodes them with a trained action
|
||||
tokenizer, and reports:
|
||||
- Reconstruction: MAE, MSE, RMSE, max absolute error, per-dimension MAE
|
||||
- Jerk: mean absolute jerk (original and reconstructed), jerk reconstruction MAE
|
||||
- Compression: ratio (input size / mean tokens), token length stats
|
||||
- Timing: mean encode/decode time per chunk
|
||||
|
||||
Results are saved to outputs/action_tokenizer_benchmark/<timestamp>_results.json.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
|
||||
--action-tokenizer-path=outputs/wavetoken \
|
||||
--repo-id=lerobot/pusht \
|
||||
--action-horizon=10 \
|
||||
--max-episodes=50 \
|
||||
--output-dir=outputs/action_tokenizer_benchmark
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
# Optional: use same helpers as train script if we want to avoid duplication
|
||||
from lerobot.scripts.lerobot_train_tokenizer import (
|
||||
apply_normalization,
|
||||
process_episode,
|
||||
)
|
||||
|
||||
|
||||
def load_action_chunks(
|
||||
repo_id: str,
|
||||
root: str | None,
|
||||
action_horizon: int,
|
||||
max_episodes: int | None,
|
||||
sample_fraction: float,
|
||||
encoded_dims: str,
|
||||
delta_dims: str | None,
|
||||
use_delta_transform: bool,
|
||||
state_key: str,
|
||||
normalization_mode: NormalizationMode,
|
||||
):
|
||||
"""Load and normalize action chunks from a LeRobot dataset (same pipeline as training)."""
|
||||
dataset = LeRobotDataset(repo_id=repo_id, root=root)
|
||||
num_episodes = dataset.num_episodes
|
||||
if max_episodes is not None:
|
||||
num_episodes = min(max_episodes, num_episodes)
|
||||
|
||||
# Parse encoded dims
|
||||
encoded_dim_ranges = []
|
||||
for range_str in encoded_dims.split(","):
|
||||
start, end = map(int, range_str.strip().split(":"))
|
||||
encoded_dim_ranges.append((start, end))
|
||||
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
|
||||
|
||||
delta_dim_list = None
|
||||
if delta_dims is not None and delta_dims.strip():
|
||||
delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")]
|
||||
|
||||
all_chunks = []
|
||||
for ep_idx in range(num_episodes):
|
||||
chunks = process_episode(
|
||||
(
|
||||
dataset,
|
||||
ep_idx,
|
||||
action_horizon,
|
||||
delta_dim_list,
|
||||
sample_fraction,
|
||||
state_key,
|
||||
use_delta_transform,
|
||||
)
|
||||
)
|
||||
if chunks is not None:
|
||||
all_chunks.append(chunks)
|
||||
|
||||
if not all_chunks:
|
||||
raise ValueError("No action chunks collected. Check action_horizon and dataset.")
|
||||
|
||||
all_chunks = np.concatenate(all_chunks, axis=0)
|
||||
|
||||
# Extract encoded dimensions only
|
||||
encoded_chunks = []
|
||||
for start, end in encoded_dim_ranges:
|
||||
encoded_chunks.append(all_chunks[:, :, start:end])
|
||||
encoded_chunks = np.concatenate(encoded_chunks, axis=-1)
|
||||
|
||||
# Normalize
|
||||
norm_stats = dataset.meta.stats
|
||||
if norm_stats is not None and ACTION in norm_stats:
|
||||
action_stats = norm_stats[ACTION]
|
||||
encoded_dim_indices = []
|
||||
for start, end in encoded_dim_ranges:
|
||||
encoded_dim_indices.extend(range(start, end))
|
||||
encoded_dim_indices = np.array(encoded_dim_indices)
|
||||
encoded_stats = {}
|
||||
for stat_name, stat_values in action_stats.items():
|
||||
if isinstance(stat_values, (list, np.ndarray)):
|
||||
stat_array = np.array(stat_values)
|
||||
if len(stat_array) > max(encoded_dim_indices):
|
||||
encoded_stats[stat_name] = stat_array[encoded_dim_indices]
|
||||
if encoded_stats:
|
||||
try:
|
||||
encoded_chunks = apply_normalization(
|
||||
encoded_chunks, encoded_stats, normalization_mode, eps=1e-8
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return encoded_chunks, total_encoded_dims, action_horizon, dataset.repo_id
|
||||
|
||||
|
||||
def compute_reconstruction_metrics(original: np.ndarray, reconstructed: np.ndarray):
|
||||
"""Compute reconstruction error metrics (original and reconstructed same shape [N, T, D])."""
|
||||
diff = reconstructed - original
|
||||
mae = float(np.mean(np.abs(diff)))
|
||||
mse = float(np.mean(diff**2))
|
||||
rmse = float(np.sqrt(mse))
|
||||
max_abs_err = float(np.max(np.abs(diff)))
|
||||
|
||||
# Per-dimension MAE (over N and T)
|
||||
per_dim_mae = np.mean(np.abs(diff), axis=(0, 1))
|
||||
per_dim_mae = per_dim_mae.tolist()
|
||||
|
||||
return {
|
||||
"reconstruction_mae": mae,
|
||||
"reconstruction_mse": mse,
|
||||
"reconstruction_rmse": rmse,
|
||||
"reconstruction_max_abs_error": max_abs_err,
|
||||
"per_dimension_mae": per_dim_mae,
|
||||
}
|
||||
|
||||
|
||||
def compute_jerk_metrics(original: np.ndarray, reconstructed: np.ndarray) -> dict:
|
||||
"""Compute jerk (3rd derivative of action w.r.t. time) metrics.
|
||||
|
||||
Args:
|
||||
original: Action chunks [N, T, D].
|
||||
reconstructed: Reconstructed action chunks [N, T, D].
|
||||
|
||||
Returns:
|
||||
Dict with mean absolute jerk for original, reconstructed, and jerk reconstruction MAE.
|
||||
"""
|
||||
# Jerk = 3rd discrete difference along time axis; need T >= 4
|
||||
if original.shape[1] < 4:
|
||||
return {}
|
||||
jerk_orig = np.diff(original, n=3, axis=1) # (N, T-3, D)
|
||||
jerk_recon = np.diff(reconstructed, n=3, axis=1)
|
||||
mae_jerk_orig = float(np.mean(np.abs(jerk_orig)))
|
||||
mae_jerk_recon = float(np.mean(np.abs(jerk_recon)))
|
||||
jerk_reconstruction_mae = float(np.mean(np.abs(jerk_recon - jerk_orig)))
|
||||
return {
|
||||
"jerk_mae_original": mae_jerk_orig,
|
||||
"jerk_mae_reconstructed": mae_jerk_recon,
|
||||
"jerk_reconstruction_mae": jerk_reconstruction_mae,
|
||||
}
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
action_chunks: np.ndarray,
|
||||
action_horizon: int,
|
||||
action_dim: int,
|
||||
tokenizer_path: str,
|
||||
max_chunks_for_reconstruction: int | None = 500,
|
||||
):
|
||||
"""Encode/decode action chunks and compute metrics."""
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
|
||||
|
||||
n_chunks = len(action_chunks)
|
||||
sample_size = n_chunks
|
||||
if max_chunks_for_reconstruction is not None:
|
||||
sample_size = min(max_chunks_for_reconstruction, n_chunks)
|
||||
rng = np.random.RandomState(42)
|
||||
indices = rng.choice(n_chunks, size=sample_size, replace=False)
|
||||
sample_chunks = action_chunks[indices]
|
||||
|
||||
# Encode
|
||||
token_lengths = []
|
||||
encode_times = []
|
||||
all_tokens = []
|
||||
for i in range(len(sample_chunks)):
|
||||
chunk = sample_chunks[i : i + 1]
|
||||
t0 = time.perf_counter()
|
||||
tokens = processor(chunk)[0]
|
||||
encode_times.append(time.perf_counter() - t0)
|
||||
if isinstance(tokens, list):
|
||||
token_lengths.append(len(tokens))
|
||||
all_tokens.append(tokens)
|
||||
else:
|
||||
n = tokens.shape[0] if hasattr(tokens, "shape") else len(tokens)
|
||||
token_lengths.append(n)
|
||||
all_tokens.append(tokens.tolist() if hasattr(tokens, "tolist") else list(tokens))
|
||||
|
||||
# Decode (processor keeps time_horizon/action_dim from encode)
|
||||
decoded_list = []
|
||||
decode_times = []
|
||||
for i, tok_list in enumerate(all_tokens):
|
||||
t0 = time.perf_counter()
|
||||
recon = processor.decode(
|
||||
[tok_list],
|
||||
time_horizon=action_horizon,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
decode_times.append(time.perf_counter() - t0)
|
||||
decoded_list.append(recon)
|
||||
decoded = np.concatenate(decoded_list, axis=0)
|
||||
|
||||
# Reconstruction metrics
|
||||
metrics = compute_reconstruction_metrics(sample_chunks, decoded)
|
||||
|
||||
# Jerk metrics (3rd derivative along time)
|
||||
jerk_metrics = compute_jerk_metrics(sample_chunks, decoded)
|
||||
metrics.update(jerk_metrics)
|
||||
|
||||
# Compression
|
||||
token_lengths = np.array(token_lengths)
|
||||
input_size = action_horizon * action_dim
|
||||
compression_ratio = input_size / float(np.mean(token_lengths))
|
||||
metrics["compression_ratio"] = compression_ratio
|
||||
metrics["mean_token_length"] = float(np.mean(token_lengths))
|
||||
metrics["std_token_length"] = float(np.std(token_lengths))
|
||||
metrics["min_token_length"] = int(np.min(token_lengths))
|
||||
metrics["max_token_length"] = int(np.max(token_lengths))
|
||||
metrics["p50_token_length"] = float(np.percentile(token_lengths, 50))
|
||||
metrics["p99_token_length"] = float(np.percentile(token_lengths, 99))
|
||||
|
||||
# Timing (seconds per chunk)
|
||||
metrics["mean_encode_time_sec"] = float(np.mean(encode_times))
|
||||
metrics["mean_decode_time_sec"] = float(np.mean(decode_times))
|
||||
metrics["num_chunks_evaluated"] = sample_size
|
||||
metrics["total_chunks_available"] = n_chunks
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def main(
|
||||
action_tokenizer_path: str,
|
||||
repo_id: str,
|
||||
root: str | None = None,
|
||||
action_horizon: int = 10,
|
||||
max_episodes: int | None = 100,
|
||||
sample_fraction: float = 0.2,
|
||||
encoded_dims: str = "0:6",
|
||||
delta_dims: str | None = None,
|
||||
use_delta_transform: bool = False,
|
||||
state_key: str = OBS_STATE,
|
||||
normalization_mode: str = "QUANTILES",
|
||||
max_chunks_for_reconstruction: int | None = 500,
|
||||
output_dir: str | None = None,
|
||||
):
|
||||
if output_dir is None:
|
||||
output_dir = "outputs/action_tokenizer_benchmark"
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
norm_mode = NormalizationMode(normalization_mode)
|
||||
except ValueError:
|
||||
norm_mode = NormalizationMode.QUANTILES
|
||||
|
||||
print("Loading action chunks...")
|
||||
encoded_chunks, action_dim, horizon, _ = load_action_chunks(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
action_horizon=action_horizon,
|
||||
max_episodes=max_episodes,
|
||||
sample_fraction=sample_fraction,
|
||||
encoded_dims=encoded_dims,
|
||||
delta_dims=delta_dims,
|
||||
use_delta_transform=use_delta_transform,
|
||||
state_key=state_key,
|
||||
normalization_mode=norm_mode,
|
||||
)
|
||||
print(f"Loaded {len(encoded_chunks)} chunks, shape {encoded_chunks.shape} (H={horizon}, D={action_dim})")
|
||||
|
||||
print("Running tokenizer benchmark...")
|
||||
metrics = run_benchmark(
|
||||
action_chunks=encoded_chunks,
|
||||
action_horizon=horizon,
|
||||
action_dim=action_dim,
|
||||
tokenizer_path=action_tokenizer_path,
|
||||
max_chunks_for_reconstruction=max_chunks_for_reconstruction,
|
||||
)
|
||||
|
||||
# Attach config for reproducibility
|
||||
results = {
|
||||
"config": {
|
||||
"action_tokenizer_path": action_tokenizer_path,
|
||||
"repo_id": repo_id,
|
||||
"action_horizon": action_horizon,
|
||||
"max_episodes": max_episodes,
|
||||
"sample_fraction": sample_fraction,
|
||||
"encoded_dims": encoded_dims,
|
||||
"delta_dims": delta_dims,
|
||||
"use_delta_transform": use_delta_transform,
|
||||
"state_key": state_key,
|
||||
"normalization_mode": normalization_mode,
|
||||
},
|
||||
"metrics": metrics,
|
||||
}
|
||||
|
||||
timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
safe_repo = repo_id.replace("/", "_")
|
||||
out_file = output_path / f"{timestamp}_{safe_repo}_action_tokenizer_results.json"
|
||||
with open(out_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
print(f"Results saved to {out_file}")
|
||||
print("Metrics:")
|
||||
for k, v in metrics.items():
|
||||
if isinstance(v, list):
|
||||
print(f" {k}: (length {len(v)})")
|
||||
else:
|
||||
print(f" {k}: {v}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark action tokenization (reconstruction error, compression, timing)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action-tokenizer-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path or HuggingFace repo id of the trained action tokenizer (e.g. outputs/wavetoken).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="LeRobot dataset repo id (e.g. lerobot/pusht).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Root directory for LeRobot datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action-horizon",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of future steps per action chunk.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-episodes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max episodes to use (default: all).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-fraction",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="Fraction of chunks to sample per episode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoded-dims",
|
||||
type=str,
|
||||
default="0:6",
|
||||
help="Dimension ranges to encode (e.g. 0:6,7:14).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delta-dims",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated dimensions for delta transform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-delta-transform",
|
||||
action="store_true",
|
||||
help="Apply delta (relative) transform to specified dimensions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--state-key",
|
||||
type=str,
|
||||
default=OBS_STATE,
|
||||
help="Dataset key for state (for delta transform).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalization-mode",
|
||||
type=str,
|
||||
default="QUANTILES",
|
||||
choices=[m.value for m in NormalizationMode],
|
||||
help="Normalization mode for actions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-chunks-for-reconstruction",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Max chunks to use for reconstruction metrics (default: 500).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="outputs/action_tokenizer_benchmark",
|
||||
help="Directory to save results JSON (default: outputs/action_tokenizer_benchmark).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(
|
||||
action_tokenizer_path=args.action_tokenizer_path,
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
action_horizon=args.action_horizon,
|
||||
max_episodes=args.max_episodes,
|
||||
sample_fraction=args.sample_fraction,
|
||||
encoded_dims=args.encoded_dims,
|
||||
delta_dims=args.delta_dims,
|
||||
use_delta_transform=args.use_delta_transform,
|
||||
state_key=args.state_key,
|
||||
normalization_mode=args.normalization_mode,
|
||||
max_chunks_for_reconstruction=args.max_chunks_for_reconstruction,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
@@ -1,13 +1,15 @@
|
||||
# Installation
|
||||
|
||||
## Install [`miniforge`](https://conda-forge.org/download/)
|
||||
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
|
||||
|
||||
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
|
||||
|
||||
```bash
|
||||
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
|
||||
bash Miniforge3-$(uname)-$(uname -m).sh
|
||||
```
|
||||
|
||||
## Environment Setup
|
||||
## Step 2: Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using conda:
|
||||
|
||||
@@ -38,7 +40,7 @@ conda install ffmpeg -c conda-forge
|
||||
>
|
||||
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
## Install LeRobot 🤗
|
||||
## Step 3: Install LeRobot 🤗
|
||||
|
||||
### From Source
|
||||
|
||||
|
||||
+3
-3
@@ -360,9 +360,9 @@ ignore_errors = false
|
||||
module = "lerobot.cameras.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.motors.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.motors.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.robots.*"
|
||||
|
||||
@@ -13,5 +13,5 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
|
||||
from .utils import make_cameras_from_configs
|
||||
|
||||
@@ -25,6 +25,10 @@ class ColorMode(str, Enum):
|
||||
RGB = "rgb"
|
||||
BGR = "bgr"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> None:
|
||||
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
|
||||
|
||||
|
||||
class Cv2Rotation(int, Enum):
|
||||
NO_ROTATION = 0
|
||||
@@ -32,6 +36,25 @@ class Cv2Rotation(int, Enum):
|
||||
ROTATE_180 = 180
|
||||
ROTATE_270 = -90
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> None:
|
||||
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
|
||||
|
||||
|
||||
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
|
||||
class Cv2Backends(int, Enum):
|
||||
ANY = 0
|
||||
V4L2 = 200
|
||||
DSHOW = 700
|
||||
PVAPI = 800
|
||||
ANDROID = 1000
|
||||
AVFOUNDATION = 1200
|
||||
MSMF = 1400
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> None:
|
||||
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
|
||||
|
||||
@@ -32,10 +32,11 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..utils import get_cv2_backend, get_cv2_rotation
|
||||
from ..utils import get_cv2_rotation
|
||||
from .configuration_opencv import ColorMode, OpenCVCameraConfig
|
||||
|
||||
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
|
||||
@@ -117,7 +118,7 @@ class OpenCVCamera(Camera):
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
self.rotation: int | None = get_cv2_rotation(config.rotation)
|
||||
self.backend: int = get_cv2_backend()
|
||||
self.backend: int = config.backend
|
||||
|
||||
if self.height and self.width:
|
||||
self.capture_width, self.capture_height = self.width, self.height
|
||||
@@ -132,6 +133,7 @@ class OpenCVCamera(Camera):
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the OpenCV camera specified in the configuration.
|
||||
@@ -148,8 +150,6 @@ class OpenCVCamera(Camera):
|
||||
ConnectionError: If the specified camera index/path is not found or fails to open.
|
||||
RuntimeError: If the camera opens but fails to apply requested settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
# Use 1 thread for OpenCV operations to avoid potential conflicts or
|
||||
# blocking in multi-threaded applications, especially during data collection.
|
||||
@@ -178,6 +178,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@check_if_not_connected
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""
|
||||
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
|
||||
@@ -197,8 +198,6 @@ class OpenCVCamera(Camera):
|
||||
to the requested value.
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
|
||||
|
||||
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
|
||||
if self.config.fourcc is not None:
|
||||
@@ -348,6 +347,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -374,9 +374,6 @@ class OpenCVCamera(Camera):
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -490,6 +487,7 @@ class OpenCVCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
@@ -512,8 +510,6 @@ class OpenCVCamera(Camera):
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
@@ -533,6 +529,7 @@ class OpenCVCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
@@ -548,8 +545,6 @@ class OpenCVCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Rotation
|
||||
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
|
||||
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
|
||||
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@@ -50,6 +50,7 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
|
||||
warmup_s: Time reading frames before returning from connect (in seconds)
|
||||
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
|
||||
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
@@ -62,22 +63,12 @@ class OpenCVCameraConfig(CameraConfig):
|
||||
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
|
||||
warmup_s: int = 1
|
||||
fourcc: str | None = None
|
||||
backend: Cv2Backends = Cv2Backends.ANY
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
if self.rotation not in (
|
||||
Cv2Rotation.NO_ROTATION,
|
||||
Cv2Rotation.ROTATE_90,
|
||||
Cv2Rotation.ROTATE_180,
|
||||
Cv2Rotation.ROTATE_270,
|
||||
):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
self.rotation = Cv2Rotation(self.rotation)
|
||||
self.backend = Cv2Backends(self.backend)
|
||||
|
||||
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
|
||||
raise ValueError(
|
||||
|
||||
@@ -74,7 +74,4 @@ class Reachy2CameraConfig(CameraConfig):
|
||||
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
|
||||
)
|
||||
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
|
||||
@@ -32,6 +32,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
|
||||
import cv2 # type: ignore # TODO: add type stubs for OpenCV
|
||||
import numpy as np # type: ignore # TODO: add type stubs for numpy
|
||||
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.import_utils import _reachy2_sdk_available
|
||||
|
||||
if TYPE_CHECKING or _reachy2_sdk_available:
|
||||
@@ -123,6 +124,7 @@ class Reachy2Camera(Camera):
|
||||
"""
|
||||
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -136,9 +138,6 @@ class Reachy2Camera(Camera):
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
@@ -184,6 +183,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Same as read()
|
||||
@@ -197,11 +197,10 @@ class Reachy2Camera(Camera):
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
return self.read()
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
@@ -219,8 +218,6 @@ class Reachy2Camera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.latest_frame is None or self.latest_timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any frames yet.")
|
||||
@@ -233,6 +230,7 @@ class Reachy2Camera(Camera):
|
||||
|
||||
return self.latest_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
@@ -240,8 +238,6 @@ class Reachy2Camera(Camera):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is already disconnected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
if self.cam_manager is not None:
|
||||
self.cam_manager.disconnect()
|
||||
|
||||
@@ -30,7 +30,8 @@ try:
|
||||
except Exception as e:
|
||||
logging.info(f"Could not import realsense: {e}")
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
@@ -152,6 +153,7 @@ class RealSenseCamera(Camera):
|
||||
"""Checks if the camera pipeline is started and streams are active."""
|
||||
return self.rs_pipeline is not None and self.rs_profile is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""
|
||||
Connects to the RealSense camera specified in the configuration.
|
||||
@@ -169,8 +171,6 @@ class RealSenseCamera(Camera):
|
||||
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
|
||||
RuntimeError: If the pipeline starts but fails to apply requested settings.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
self.rs_pipeline = rs.pipeline()
|
||||
rs_config = rs.config()
|
||||
@@ -290,6 +290,7 @@ class RealSenseCamera(Camera):
|
||||
if self.use_depth:
|
||||
rs_config.enable_stream(rs.stream.depth)
|
||||
|
||||
@check_if_not_connected
|
||||
def _configure_capture_settings(self) -> None:
|
||||
"""Sets fps, width, and height from device stream if not already configured.
|
||||
|
||||
@@ -299,8 +300,6 @@ class RealSenseCamera(Camera):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If device is not connected.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
|
||||
|
||||
if self.rs_profile is None:
|
||||
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
|
||||
@@ -320,6 +319,7 @@ class RealSenseCamera(Camera):
|
||||
self.width, self.height = actual_width, actual_height
|
||||
self.capture_width, self.capture_height = actual_width, actual_height
|
||||
|
||||
@check_if_not_connected
|
||||
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (depth) synchronously from the camera.
|
||||
@@ -345,9 +345,6 @@ class RealSenseCamera(Camera):
|
||||
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -374,6 +371,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame (color) synchronously from the camera.
|
||||
@@ -403,9 +401,6 @@ class RealSenseCamera(Camera):
|
||||
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -534,6 +529,7 @@ class RealSenseCamera(Camera):
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame data (color) asynchronously.
|
||||
@@ -556,8 +552,6 @@ class RealSenseCamera(Camera):
|
||||
TimeoutError: If no frame data becomes available within the specified timeout.
|
||||
RuntimeError: If the background thread died unexpectedly or another error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
@@ -578,6 +572,7 @@ class RealSenseCamera(Camera):
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
|
||||
@@ -593,8 +588,6 @@ class RealSenseCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -60,20 +60,8 @@ class RealSenseCameraConfig(CameraConfig):
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
if self.rotation not in (
|
||||
Cv2Rotation.NO_ROTATION,
|
||||
Cv2Rotation.ROTATE_90,
|
||||
Cv2Rotation.ROTATE_180,
|
||||
Cv2Rotation.ROTATE_270,
|
||||
):
|
||||
raise ValueError(
|
||||
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
|
||||
)
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
self.rotation = Cv2Rotation(self.rotation)
|
||||
|
||||
values = (self.fps, self.width, self.height)
|
||||
if any(v is not None for v in values) and any(v is None for v in values):
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import platform
|
||||
from typing import cast
|
||||
|
||||
from lerobot.utils.import_utils import make_device_from_device_class
|
||||
@@ -68,14 +67,3 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
|
||||
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_cv2_backend() -> int:
|
||||
import cv2
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
|
||||
# elif platform.system() == "Darwin": # macOS
|
||||
# return cv2.CAP_AVFOUNDATION
|
||||
else: # Linux and others
|
||||
return int(cv2.CAP_ANY)
|
||||
|
||||
@@ -34,7 +34,8 @@ import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
@@ -104,6 +105,7 @@ class ZMQCamera(Camera):
|
||||
"""Checks if the ZMQ socket is initialized and connected."""
|
||||
return self._connected and self.context is not None and self.socket is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, warmup: bool = True) -> None:
|
||||
"""Connect to ZMQ camera server.
|
||||
|
||||
@@ -111,8 +113,6 @@ class ZMQCamera(Camera):
|
||||
warmup (bool): If True, waits for the camera to provide at least one
|
||||
valid frame before returning. Defaults to True.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
|
||||
|
||||
logger.info(f"Connecting to {self}...")
|
||||
|
||||
@@ -211,6 +211,7 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
@@ -228,9 +229,6 @@ class ZMQCamera(Camera):
|
||||
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
|
||||
)
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -301,6 +299,7 @@ class ZMQCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
@@ -317,8 +316,6 @@ class ZMQCamera(Camera):
|
||||
TimeoutError: If no frame data becomes available within the specified timeout.
|
||||
RuntimeError: If the background thread is not running.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
@@ -335,6 +332,7 @@ class ZMQCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
|
||||
"""Return the most recent frame captured immediately (Peeking).
|
||||
|
||||
@@ -350,8 +348,6 @@ class ZMQCamera(Camera):
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If the camera is connected but has not captured any frames yet.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
@@ -32,10 +32,7 @@ class ZMQCameraConfig(CameraConfig):
|
||||
warmup_s: int = 1
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
|
||||
)
|
||||
self.color_mode = ColorMode(self.color_mode)
|
||||
|
||||
if self.timeout_ms <= 0:
|
||||
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")
|
||||
|
||||
@@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
|
||||
normalization mode to apply.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||
the original scale.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
|
||||
@@ -216,16 +216,17 @@ class ImageTransformsConfig:
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
if cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
elif cfg.type == "RandomAffine":
|
||||
return v2.RandomAffine(**cfg.kwargs)
|
||||
else:
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
transform_cls = getattr(v2, cfg.type, None)
|
||||
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
|
||||
return transform_cls(**cfg.kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Transform '{cfg.type}' is not valid. It must be a class in "
|
||||
f"torchvision.transforms.v2 or 'SharpnessJitter'."
|
||||
)
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
|
||||
@@ -205,6 +205,7 @@ class ObservationConfig:
|
||||
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -112,6 +112,7 @@ class LiberoEnv(gym.Env):
|
||||
visualization_height: int = 480,
|
||||
init_states: bool = True,
|
||||
episode_index: int = 0,
|
||||
n_envs: int = 1,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
num_steps_wait: int = 10,
|
||||
control_mode: str = "relative",
|
||||
@@ -145,7 +146,9 @@ class LiberoEnv(gym.Env):
|
||||
self.episode_length = episode_length
|
||||
# Load once and keep
|
||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
|
||||
|
||||
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
|
||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||
default_steps = 500
|
||||
@@ -295,7 +298,8 @@ class LiberoEnv(gym.Env):
|
||||
self._env.seed(seed)
|
||||
raw_obs = self._env.reset()
|
||||
if self.init_states and self._init_states is not None:
|
||||
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
|
||||
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
|
||||
self.init_state_id += self._reset_stride # Change init_state_id when reset
|
||||
|
||||
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||
@@ -373,6 +377,7 @@ def _make_env_fns(
|
||||
init_states=init_states,
|
||||
episode_length=episode_length,
|
||||
episode_index=episode_index,
|
||||
n_envs=n_envs,
|
||||
control_mode=control_mode,
|
||||
**local_kwargs,
|
||||
)
|
||||
|
||||
@@ -221,7 +221,7 @@ class RangeFinderGUI:
|
||||
|
||||
self.bus = bus
|
||||
self.groups = groups if groups is not None else {"all": list(bus.motors)}
|
||||
self.group_names = list(groups)
|
||||
self.group_names = list(self.groups)
|
||||
self.current_group = self.group_names[0]
|
||||
|
||||
if not bus.is_connected:
|
||||
@@ -230,18 +230,20 @@ class RangeFinderGUI:
|
||||
self.calibration = bus.read_calibration()
|
||||
self.res_table = bus.model_resolution_table
|
||||
self.present_cache = {
|
||||
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
|
||||
m: bus.read("Present_Position", m, normalize=False)
|
||||
for motors in self.groups.values()
|
||||
for m in motors
|
||||
}
|
||||
|
||||
pygame.init()
|
||||
self.font = pygame.font.Font(None, FONT_SIZE)
|
||||
|
||||
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
|
||||
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
|
||||
self.label_pad = label_pad
|
||||
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
|
||||
self.controls_bottom = 10 + SAVE_H
|
||||
self.base_y = self.controls_bottom + TOP_GAP
|
||||
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
|
||||
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
|
||||
|
||||
self.screen = pygame.display.set_mode((width, height))
|
||||
pygame.display.set_caption("Motors range finder")
|
||||
|
||||
@@ -23,6 +23,7 @@ from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
|
||||
if TYPE_CHECKING or _can_available:
|
||||
@@ -36,7 +37,6 @@ else:
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
@@ -155,6 +155,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
"""Check if the CAN bus is connected."""
|
||||
return self._is_connected and self.canbus is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""
|
||||
Open the CAN bus and initialize communication.
|
||||
@@ -162,10 +163,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Args:
|
||||
handshake: If True, ping all motors to verify they're present
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected."
|
||||
)
|
||||
|
||||
try:
|
||||
# Auto-detect interface type based on port name
|
||||
@@ -211,6 +208,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
logger.info("Starting handshake with motors...")
|
||||
|
||||
# Drain any pending messages
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
while self.canbus.recv(timeout=0.01):
|
||||
pass
|
||||
|
||||
@@ -246,6 +246,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
)
|
||||
logger.info("Handshake successful. All motors ready.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""
|
||||
Close the CAN bus connection.
|
||||
@@ -253,8 +254,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Args:
|
||||
disable_torque: If True, disable torque on all motors before disconnecting
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
|
||||
|
||||
if disable_torque:
|
||||
try:
|
||||
@@ -283,6 +282,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [0xFF] * 7 + [command_byte]
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
self.canbus.send(msg)
|
||||
if msg := self._recv_motor_response(expected_recv_id=recv_id):
|
||||
self._process_response(motor_name, msg)
|
||||
@@ -341,6 +344,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
recv_id = self._get_motor_recv_id(motor)
|
||||
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
|
||||
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
self.canbus.send(msg)
|
||||
return self._recv_motor_response(expected_recv_id=recv_id)
|
||||
|
||||
@@ -356,6 +363,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Returns:
|
||||
CAN message if received, None otherwise
|
||||
"""
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages_seen = []
|
||||
@@ -394,10 +405,13 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Returns:
|
||||
Dictionary mapping recv_id to CAN message
|
||||
"""
|
||||
responses = {}
|
||||
responses: dict[int, can.Message] = {}
|
||||
expected_set = set(expected_recv_ids)
|
||||
start_time = time.time()
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
try:
|
||||
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
|
||||
# 100us poll timeout
|
||||
@@ -461,6 +475,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
motor_name = self._get_motor_name(motor)
|
||||
motor_type = self._motor_types[motor_name]
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
|
||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
|
||||
self.canbus.send(msg)
|
||||
@@ -488,6 +505,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
# Step 1: Send all MIT control commands
|
||||
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
@@ -562,10 +582,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
||||
|
||||
@check_if_not_connected
|
||||
def read(self, data_name: str, motor: str) -> Value:
|
||||
"""Read a value from a single motor. Positions are always in degrees."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Refresh motor to get latest state
|
||||
msg = self._refresh_motor(motor)
|
||||
@@ -595,6 +614,7 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
raise ValueError(f"Unknown data_name: {data_name}")
|
||||
return mapping[data_name]
|
||||
|
||||
@check_if_not_connected
|
||||
def write(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -605,8 +625,6 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
Write a value to a single motor. Positions are always in degrees.
|
||||
Can write 'Goal_Position', 'Kp', or 'Kd'.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
self._gains[motor][data_name.lower()] = float(value)
|
||||
@@ -656,6 +674,10 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
def _batch_refresh(self, motors: list[str]) -> None:
|
||||
"""Internal helper to refresh a list of motors and update cache."""
|
||||
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
|
||||
# Send refresh commands
|
||||
for motor in motors:
|
||||
motor_id = self._get_motor_id(motor)
|
||||
@@ -678,10 +700,12 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
@check_if_not_connected
|
||||
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
|
||||
"""
|
||||
Write values to multiple motors simultaneously. Positions are always in degrees.
|
||||
"""
|
||||
|
||||
if data_name in ("Kp", "Kd"):
|
||||
key = data_name.lower()
|
||||
for motor, val in values.items():
|
||||
@@ -690,6 +714,8 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
elif data_name == "Goal_Position":
|
||||
# Step 1: Send all MIT control commands
|
||||
recv_id_to_motor: dict[int, str] = {}
|
||||
if self.canbus is None:
|
||||
raise RuntimeError("CAN bus is not initialized.")
|
||||
for motor, value_degrees in values.items():
|
||||
motor_id = self._get_motor_id(motor)
|
||||
motor_name = self._get_motor_name(motor)
|
||||
@@ -732,9 +758,9 @@ class DamiaoMotorsBus(MotorsBusBase):
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self,
|
||||
motors: NameOrID | list[NameOrID] | None = None,
|
||||
motors: str | list[str] | None = None,
|
||||
display_values: bool = True,
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
) -> tuple[dict[str, Value], dict[str, Value]]:
|
||||
"""
|
||||
Interactively record the min/max values of each motor in degrees.
|
||||
|
||||
|
||||
@@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
for motor, m in self.motors.items():
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=drive_modes[motor],
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
drive_mode=int(drive_modes[motor]),
|
||||
homing_offset=int(offsets[motor]),
|
||||
range_min=int(mins[motor]),
|
||||
range_max=int(maxes[motor]),
|
||||
)
|
||||
|
||||
return calibration
|
||||
@@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
if cache:
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
@@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
|
||||
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
|
||||
@@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
On Dynamixel Motors:
|
||||
Present_Position = Actual_Position + Homing_Offset
|
||||
"""
|
||||
half_turn_homings = {}
|
||||
half_turn_homings: dict[NameOrID, Value] = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
@@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
|
||||
return
|
||||
return None
|
||||
|
||||
return {id_: data[0] for id_, data in data_list.items()}
|
||||
|
||||
@@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
|
||||
self.port_handler = scs.PortHandler(self.port)
|
||||
# HACK: monkeypatch
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
|
||||
self.port_handler, scs.PortHandler
|
||||
)
|
||||
self.packet_handler = scs.PacketHandler(protocol_version)
|
||||
@@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
calibration[motor] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=0,
|
||||
homing_offset=offsets[motor],
|
||||
range_min=mins[motor],
|
||||
range_max=maxes[motor],
|
||||
homing_offset=int(offsets[motor]),
|
||||
range_min=int(mins[motor]),
|
||||
range_max=int(maxes[motor]),
|
||||
)
|
||||
|
||||
return calibration
|
||||
@@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
On Feetech Motors:
|
||||
Present_Position = Actual_Position - Homing_Offset
|
||||
"""
|
||||
half_turn_homings = {}
|
||||
half_turn_homings: dict[NameOrID, Value] = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
@@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
|
||||
return half_turn_homings
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 0, num_retry=num_retry)
|
||||
@@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
addr, length = get_address(self.model_ctrl_table, model, "Lock")
|
||||
self._write(addr, length, motor, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for motor in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", motor, 1, num_retry=num_retry)
|
||||
@@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
|
||||
import scservo_sdk as scs
|
||||
|
||||
data_list = {}
|
||||
data_list: dict[int, int] = {}
|
||||
|
||||
status_length = 6
|
||||
|
||||
@@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus):
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
return
|
||||
return None
|
||||
|
||||
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
||||
if ids_errors:
|
||||
|
||||
@@ -23,6 +23,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
|
||||
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
|
||||
"""Write values to multiple motors."""
|
||||
pass
|
||||
|
||||
@@ -179,15 +180,16 @@ class Motor:
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool
|
||||
self.baudrate: int
|
||||
self.packet_start_time: float
|
||||
self.packet_timeout: float
|
||||
self.tx_time_per_byte: float
|
||||
self.is_using: bool
|
||||
self.port_name: str
|
||||
self.ser: serial.Serial
|
||||
is_open: bool
|
||||
baudrate: int
|
||||
packet_start_time: float
|
||||
packet_timeout: float
|
||||
tx_time_per_byte: float
|
||||
is_using: bool
|
||||
port_name: str
|
||||
ser: serial.Serial
|
||||
|
||||
def __init__(self, port_name: str) -> None: ...
|
||||
|
||||
def openPort(self): ...
|
||||
def closePort(self): ...
|
||||
@@ -240,19 +242,22 @@ class PacketHandler(Protocol):
|
||||
def regWriteTxRx(self, port, id, address, length, data): ...
|
||||
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
|
||||
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
|
||||
def broadcastPing(self, port): ...
|
||||
|
||||
|
||||
class GroupSyncRead(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.last_result: bool
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
port: str
|
||||
ph: PortHandler
|
||||
start_address: int
|
||||
data_length: int
|
||||
last_result: bool
|
||||
is_param_changed: bool
|
||||
param: list
|
||||
data_dict: dict
|
||||
|
||||
def __init__(
|
||||
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
|
||||
) -> None: ...
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id): ...
|
||||
def removeParam(self, id): ...
|
||||
@@ -265,15 +270,17 @@ class GroupSyncRead(Protocol):
|
||||
|
||||
|
||||
class GroupSyncWrite(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
port: str
|
||||
ph: PortHandler
|
||||
start_address: int
|
||||
data_length: int
|
||||
is_param_changed: bool
|
||||
param: list
|
||||
data_dict: dict
|
||||
|
||||
def __init__(
|
||||
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
|
||||
) -> None: ...
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id, data): ...
|
||||
def removeParam(self, id): ...
|
||||
@@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motor_model(self, motor: NameOrID) -> int:
|
||||
def _get_motor_model(self, motor: NameOrID) -> str:
|
||||
if isinstance(motor, str):
|
||||
return self.motors[motor].model
|
||||
elif isinstance(motor, int):
|
||||
@@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
|
||||
if motors is None:
|
||||
return list(self.motors)
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors.copy()
|
||||
elif isinstance(motors, int):
|
||||
return [self._id_to_name(motors)]
|
||||
elif isinstance(motors, Sequence):
|
||||
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
|
||||
else:
|
||||
raise TypeError(motors)
|
||||
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
|
||||
if isinstance(values, (int | float)):
|
||||
return dict.fromkeys(self.ids, values)
|
||||
elif isinstance(values, dict):
|
||||
@@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
"""Enable torque on selected motors.
|
||||
|
||||
Args:
|
||||
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
|
||||
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
|
||||
Defaults to `None`.
|
||||
num_retry (int, optional): Number of additional retry attempts on communication failure.
|
||||
Defaults to 0.
|
||||
"""
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self, motors: int | str | list[str] | None = None):
|
||||
def torque_disabled(self, motors: str | list[str] | None = None):
|
||||
"""Context-manager that guarantees torque is re-enabled.
|
||||
|
||||
This helper is useful to temporarily disable torque when configuring motors.
|
||||
@@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
"""
|
||||
pass
|
||||
|
||||
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
|
||||
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
|
||||
"""Restore factory calibration for the selected motors.
|
||||
|
||||
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
|
||||
The in-memory :pyattr:`calibration` is cleared.
|
||||
|
||||
Args:
|
||||
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
|
||||
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
|
||||
resets every motor.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
motor_names = self._get_motors_list(motors)
|
||||
|
||||
for motor in motors:
|
||||
for motor in motor_names:
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
self.write("Homing_Offset", motor, 0, normalize=False)
|
||||
@@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
self.calibration = {}
|
||||
|
||||
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
|
||||
def set_half_turn_homings(
|
||||
self, motors: NameOrID | Sequence[NameOrID] | None = None
|
||||
) -> dict[NameOrID, Value]:
|
||||
"""Centre each motor range around its current position.
|
||||
|
||||
The function computes and writes a homing offset such that the present position becomes exactly one
|
||||
@@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
|
||||
|
||||
Returns:
|
||||
dict[NameOrID, Value]: Mapping *motor → written homing offset*.
|
||||
dict[str, Value]: Mapping *motor name → written homing offset*.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
motor_names = self._get_motors_list(motors)
|
||||
|
||||
self.reset_calibration(motors)
|
||||
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
self.reset_calibration(motor_names)
|
||||
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
homing_offsets = self._get_half_turn_homings(actual_positions)
|
||||
for motor, offset in homing_offsets.items():
|
||||
self.write("Homing_Offset", motor, offset)
|
||||
@@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
pass
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[str, Value], dict[str, Value]]:
|
||||
"""Interactively record the min/max encoder values of each motor.
|
||||
|
||||
Move the joints by hand (with torque disabled) while the method streams live positions. Press
|
||||
@@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
display_values (bool, optional): When `True` (default) a live table is printed to the console.
|
||||
|
||||
Returns:
|
||||
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
|
||||
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
|
||||
extreme values observed for each motor.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
motor_names = self._get_motors_list(motors)
|
||||
|
||||
start_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
|
||||
user_pressed_enter = False
|
||||
while not user_pressed_enter:
|
||||
positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
positions = self.sync_read("Present_Position", motor_names, normalize=False)
|
||||
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
|
||||
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
|
||||
|
||||
if display_values:
|
||||
print("\n-------------------------------------------")
|
||||
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
|
||||
for motor in motors:
|
||||
for motor in motor_names:
|
||||
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
|
||||
|
||||
if enter_pressed():
|
||||
@@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
if display_values and not user_pressed_enter:
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(motors) + 3)
|
||||
move_cursor_up(len(motor_names) + 3)
|
||||
|
||||
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
|
||||
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
|
||||
if same_min_max:
|
||||
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
|
||||
|
||||
@@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
else:
|
||||
return
|
||||
return None
|
||||
if self._is_error(error):
|
||||
if raise_on_error:
|
||||
raise RuntimeError(self.packet_handler.getRxPacketError(error))
|
||||
else:
|
||||
return
|
||||
return None
|
||||
|
||||
return model_number
|
||||
|
||||
@@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
||||
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
id_value = self._decode_sign(data_name, {id_: value})
|
||||
decoded = self._decode_sign(data_name, {id_: value})
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
id_value = self._normalize(id_value)
|
||||
normalized = self._normalize(decoded)
|
||||
return normalized[id_]
|
||||
|
||||
return id_value[id_]
|
||||
return decoded[id_]
|
||||
|
||||
def _read(
|
||||
self,
|
||||
@@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[int, int]:
|
||||
) -> tuple[int, int, int]:
|
||||
if length == 1:
|
||||
read_fn = self.packet_handler.read1ByteTxRx
|
||||
elif length == 2:
|
||||
@@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
model = self.motors[motor].model
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
int_value = int(value)
|
||||
if normalize and data_name in self.normalized_data:
|
||||
value = self._unnormalize({id_: value})[id_]
|
||||
int_value = self._unnormalize({id_: value})[id_]
|
||||
|
||||
value = self._encode_sign(data_name, {id_: value})[id_]
|
||||
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
|
||||
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
def _write(
|
||||
self,
|
||||
@@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: str | list[str] | None = None,
|
||||
motors: NameOrID | Sequence[NameOrID] | None = None,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
@@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
|
||||
Args:
|
||||
data_name (str): Register name.
|
||||
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
|
||||
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
|
||||
normalize (bool, optional): Normalisation flag. Defaults to `True`.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
|
||||
@@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
||||
ids_values, _ = self._sync_read(
|
||||
raw_ids_values, _ = self._sync_read(
|
||||
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
|
||||
ids_values = self._decode_sign(data_name, ids_values)
|
||||
decoded = self._decode_sign(data_name, raw_ids_values)
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
ids_values = self._normalize(ids_values)
|
||||
normalized = self._normalize(decoded)
|
||||
return {self._id_to_name(id_): value for id_, value in normalized.items()}
|
||||
|
||||
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
|
||||
return {self._id_to_name(id_): value for id_, value in decoded.items()}
|
||||
|
||||
def _sync_read(
|
||||
self,
|
||||
@@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase):
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
|
||||
ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||
raw_ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in raw_ids_values]
|
||||
if self._has_different_ctrl_tables:
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = next(iter(models))
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
|
||||
if normalize and data_name in self.normalized_data:
|
||||
ids_values = self._unnormalize(ids_values)
|
||||
int_ids_values = self._unnormalize(raw_ids_values)
|
||||
|
||||
ids_values = self._encode_sign(data_name, ids_values)
|
||||
int_ids_values = self._encode_sign(data_name, int_ids_values)
|
||||
|
||||
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
|
||||
self._sync_write(
|
||||
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
|
||||
def _sync_write(
|
||||
self,
|
||||
|
||||
@@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and 'output_shapes`.
|
||||
Those are: `input_features` and `output_features`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- Either:
|
||||
@@ -48,21 +48,12 @@ class ACTConfig(PreTrainedConfig):
|
||||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
|
||||
@@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
Those are: `input_features` and `output_features`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -48,21 +48,12 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
@@ -73,7 +64,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
|
||||
@@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
loss_dict["losses_after_forward"] = losses.clone().mean().item()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
|
||||
@@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
@@ -40,24 +40,12 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
|
||||
match the original implementation.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
|
||||
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
|
||||
normalization mode here.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
|
||||
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
|
||||
latent_dim: Observation's latent embedding dimension.
|
||||
|
||||
@@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
Those are: `input_features` and `output_features`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
@@ -46,21 +46,12 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
current step and additional steps going back).
|
||||
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
|
||||
action_chunk_size: Action chunk size of each action prediction token.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
|
||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# copy over non-STATE features
|
||||
for ft, feats in features.items():
|
||||
if ft != PipelineFeatureType.STATE:
|
||||
if ft != FeatureType.STATE:
|
||||
new_features[ft] = feats.copy()
|
||||
|
||||
# rebuild STATE features
|
||||
@@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
|
||||
# add our new flattened state
|
||||
state_feats[OBS_STATE] = PolicyFeature(
|
||||
key=OBS_STATE,
|
||||
type=FeatureType.STATE,
|
||||
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
|
||||
dtype="float32",
|
||||
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
|
||||
)
|
||||
|
||||
new_features[PipelineFeatureType.STATE] = state_feats
|
||||
new_features[FeatureType.STATE] = state_feats
|
||||
|
||||
return new_features
|
||||
|
||||
|
||||
@@ -314,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
@@ -329,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
The modified transition with the penalty added to complementary data.
|
||||
"""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||
if raw_joint_positions is None:
|
||||
return complementary_data
|
||||
return new_transition
|
||||
|
||||
current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return complementary_data
|
||||
return new_transition
|
||||
|
||||
# Gripper action is a PolicyAction at this stage
|
||||
gripper_action = action[-1].item()
|
||||
@@ -364,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Create new complementary data with penalty info
|
||||
# Update complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
|
||||
return new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -412,7 +412,10 @@ def make_processors(
|
||||
if cfg.processor.observation.add_current_to_observation:
|
||||
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
||||
|
||||
if kinematics_solver is not None:
|
||||
add_ee_pose = (
|
||||
cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation
|
||||
)
|
||||
if kinematics_solver is not None and add_ee_pose:
|
||||
env_pipeline_steps.append(
|
||||
ForwardKinematicsJointsToEEObservation(
|
||||
kinematics=kinematics_solver,
|
||||
@@ -435,7 +438,12 @@ def make_processors(
|
||||
)
|
||||
|
||||
# Add gripper penalty processor if gripper config exists and enabled
|
||||
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
|
||||
# Only add if max_gripper_pos is explicitly configured (required for normalization)
|
||||
if (
|
||||
cfg.processor.gripper is not None
|
||||
and cfg.processor.gripper.use_gripper
|
||||
and cfg.processor.max_gripper_pos is not None
|
||||
):
|
||||
env_pipeline_steps.append(
|
||||
GripperPenaltyProcessorStep(
|
||||
penalty=cfg.processor.gripper.gripper_penalty,
|
||||
|
||||
@@ -26,8 +26,21 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
|
||||
|
||||
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
def cfg_to_group(
|
||||
cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64
|
||||
) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
|
||||
def _maybe_truncate(tag: str) -> str:
|
||||
"""Truncate tag to max_tag_length characters if required.
|
||||
|
||||
wandb rejects tags longer than 64 characters.
|
||||
See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py
|
||||
"""
|
||||
if len(tag) <= max_tag_length:
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"seed:{cfg.seed}",
|
||||
@@ -36,6 +49,8 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
if truncate_tags:
|
||||
lst = [_maybe_truncate(tag) for tag in lst]
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
|
||||
@@ -83,7 +98,7 @@ class WandBLogger:
|
||||
entity=self.cfg.entity,
|
||||
name=self.job_name,
|
||||
notes=self.cfg.notes,
|
||||
tags=cfg_to_group(cfg, return_list=True),
|
||||
tags=cfg_to_group(cfg, return_list=True, truncate_tags=True),
|
||||
dir=self.log_dir,
|
||||
config=cfg.to_dict(),
|
||||
# TODO(rcadene): try set to True
|
||||
|
||||
@@ -19,6 +19,7 @@ from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
@@ -112,6 +113,7 @@ class BiOpenArmFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -133,6 +135,7 @@ class BiOpenArmFollower(Robot):
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
@@ -146,6 +149,7 @@ class BiOpenArmFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
@@ -170,6 +174,7 @@ class BiOpenArmFollower(Robot):
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -19,6 +19,7 @@ from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_bi_so_follower import BiSOFollowerConfig
|
||||
@@ -96,6 +97,7 @@ class BiSOFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -116,6 +118,7 @@ class BiSOFollower(Robot):
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
@@ -129,6 +132,7 @@ class BiSOFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
# Remove "left_" prefix
|
||||
left_action = {
|
||||
@@ -148,6 +152,7 @@ class BiSOFollower(Robot):
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -23,7 +23,7 @@ from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -119,6 +119,7 @@ class OpenArmFollower(Robot):
|
||||
"""Check if robot is connected."""
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
Connect to the robot and optionally calibrate.
|
||||
@@ -126,8 +127,6 @@ class OpenArmFollower(Robot):
|
||||
We assume that at connection time, the arms are in a safe rest position,
|
||||
and torque can be safely disabled to run calibration if needed.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
@@ -219,6 +218,7 @@ class OpenArmFollower(Robot):
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Get current observation from robot including position, velocity, and torque.
|
||||
@@ -228,9 +228,6 @@ class OpenArmFollower(Robot):
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict: dict[str, Any] = {}
|
||||
|
||||
states = self.bus.sync_read_all_states()
|
||||
@@ -253,6 +250,7 @@ class OpenArmFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(
|
||||
self,
|
||||
action: RobotAction,
|
||||
@@ -272,8 +270,6 @@ class OpenArmFollower(Robot):
|
||||
Returns:
|
||||
The action actually sent (potentially clipped)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -333,10 +329,9 @@ class OpenArmFollower(Robot):
|
||||
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
"""Disconnect from robot."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Disconnect CAN bus
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
|
||||
@@ -45,7 +45,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
from lerobot.utils.import_utils import _can_available
|
||||
|
||||
MOTOR_NAMES = {
|
||||
0x01: "joint_1",
|
||||
@@ -336,7 +336,7 @@ def run_speed(cfg: CANSetupConfig):
|
||||
|
||||
@draccus.wrap()
|
||||
def setup_can(cfg: CANSetupConfig):
|
||||
if not is_package_available("can"):
|
||||
if not _can_available:
|
||||
print("Error: python-can not installed. Install with: pip install python-can")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@@ -166,9 +166,9 @@ def apply_normalization(
|
||||
if q01 is None or q99 is None:
|
||||
raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats")
|
||||
denom = np.maximum(q99 - q01, eps)
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q01, q99)
|
||||
return 2.0 * (clipped - q01) / denom - 1.0
|
||||
# No clipping: match training pipeline NormalizerProcessorStep so tokenizer
|
||||
# is fit on the full range of normalized values (including tails outside [-1, 1]).
|
||||
return 2.0 * (data - q01) / denom - 1.0
|
||||
|
||||
if mode == NormalizationMode.QUANTILE10:
|
||||
q10 = stats.get("q10")
|
||||
@@ -176,9 +176,8 @@ def apply_normalization(
|
||||
if q10 is None or q90 is None:
|
||||
raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats")
|
||||
denom = np.maximum(q90 - q10, eps)
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q10, q90)
|
||||
return 2.0 * (clipped - q10) / denom - 1.0
|
||||
# No clipping: match training pipeline NormalizerProcessorStep.
|
||||
return 2.0 * (data - q10) / denom - 1.0
|
||||
|
||||
raise ValueError(f"Unsupported normalization mode: {mode}")
|
||||
|
||||
@@ -306,7 +305,7 @@ def train_fast_tokenizer(
|
||||
|
||||
# download the tokenizer source code (not pretrained weights)
|
||||
# we'll train a new tokenizer on our own data
|
||||
base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
base_tokenizer = AutoProcessor.from_pretrained("/fsx/jade_choghari/outputs/libero_tokenizer_wavetoken1", trust_remote_code=True)
|
||||
|
||||
# convert action_chunks array to list of arrays (expected by .fit())
|
||||
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
|
||||
@@ -320,6 +319,8 @@ def train_fast_tokenizer(
|
||||
vocab_size=vocab_size,
|
||||
time_horizon=action_chunks.shape[1], # action_horizon
|
||||
action_dim=action_chunks.shape[2], # encoded dimensions
|
||||
wavelet="dmey",
|
||||
level=1,
|
||||
)
|
||||
print("✓ Tokenizer training complete!")
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from functools import cached_property
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..openarm_leader import OpenArmLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -88,6 +89,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -109,6 +111,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
|
||||
@@ -126,6 +129,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..so_leader import SOLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -72,6 +72,7 @@ class BiSOLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
@@ -110,6 +111,7 @@ class BiSOLeader(Teleoperator):
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import Any
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_leader import OpenArmLeaderConfig
|
||||
@@ -84,6 +84,7 @@ class OpenArmLeader(Teleoperator):
|
||||
"""Check if teleoperator is connected."""
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
Connect to the teleoperator.
|
||||
@@ -91,8 +92,6 @@ class OpenArmLeader(Teleoperator):
|
||||
For manual control, we disable torque after connecting so the
|
||||
arm can be moved by hand.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
@@ -183,6 +182,7 @@ class OpenArmLeader(Teleoperator):
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Get current action from the leader arm.
|
||||
@@ -193,8 +193,6 @@ class OpenArmLeader(Teleoperator):
|
||||
Reads all motor states (pos/vel/torque) in one CAN refresh cycle.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
action_dict: dict[str, Any] = {}
|
||||
|
||||
@@ -214,10 +212,9 @@ class OpenArmLeader(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from teleoperator."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Disconnect CAN bus
|
||||
# For manual control, ensure torque is disabled before disconnecting
|
||||
|
||||
@@ -390,6 +390,30 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
SharpnessJitter((2.0, 0.1))
|
||||
|
||||
|
||||
def test_make_transform_from_config_with_v2_resize(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
assert isinstance(tf, v2.Resize)
|
||||
output = tf(img_tensor)
|
||||
assert output.shape[-2:] == (32, 32)
|
||||
|
||||
|
||||
def test_make_transform_from_config_with_v2_identity(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformConfig(type="Identity", kwargs={})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
assert isinstance(tf, v2.Identity)
|
||||
output = tf(img_tensor)
|
||||
assert output.shape == img_tensor.shape
|
||||
|
||||
|
||||
def test_make_transform_from_config_invalid_type():
|
||||
tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={})
|
||||
with pytest.raises(ValueError, match="not valid"):
|
||||
make_transform_from_config(tf_cfg)
|
||||
|
||||
|
||||
def test_save_all_transforms(img_tensor_factory, tmp_path):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(enable=True)
|
||||
|
||||
Reference in New Issue
Block a user