Compare commits

..

1 Commits

Author SHA1 Message Date
Pepijn e07bc8fd97 Potential noise injection for data recording: DART 2026-02-04 11:56:33 +01:00
36 changed files with 301 additions and 900 deletions
+3 -5
View File
@@ -101,11 +101,9 @@ jobs:
runs-on:
group: aws-general-8-plus
if: |
github.repository == 'huggingface/lerobot' && (
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
)
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) ||
github.event_name == 'push' ||
github.event_name == 'workflow_dispatch'
outputs:
image_tag: ${{ steps.set_tag.outputs.image_tag }}
env:
-1
View File
@@ -91,7 +91,6 @@ jobs:
name: Build and Push Docker
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME }}
env:
-134
View File
@@ -1,134 +0,0 @@
# Action tokenizer benchmark
## Questions
What is the trade-off between:
- **Compression**: how many tokens are needed to represent an action chunk (e.g. horizon × action_dim floats)?
- **Reconstruction quality**: how well does encode-then-decode preserve the original actions?
- **Speed**: how long does encoding and decoding take per chunk?
How to choose an action tokenizer?
- Which tokenizer architecture (e.g. dct + BPE, DCT + BPE)?
- Which **action horizon** and **encoded dimensions** to use?
- Which **normalization** (QUANTILES, MEAN_STD, MIN_MAX) and **delta transform** (relative vs absolute actions)?
- How do reconstruction error and compression ratio vary across datasets and tokenizer settings?
This benchmark loads action chunks from a LeRobot dataset using the same pipeline as `lerobot-train-tokenizer`, runs a trained action tokenizer in encode/decode mode, and reports reconstruction error, compression stats, and timing. Results are saved as JSON under `outputs/` for comparison and analysis.
## Variables
**Dataset & chunking**
- **repo_id**: LeRobot dataset (e.g. `lerobot/pusht`). Action statistics and normalization are taken from the dataset metadata when available.
- **action_horizon**: Number of future steps per action chunk (must match the tokenizers training).
- **encoded_dims**: Dimension ranges to encode (e.g. `0:6` or `0:6,7:14`). Must match the tokenizer.
- **max_episodes**: Cap on episodes to load (default: all).
- **sample_fraction**: Fraction of chunks to sample per episode (default `0.2`) to keep runtime manageable.
**Transform & normalization**
- **normalization_mode**: `IDENTITY`, `MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`. Should match the tokenizers training.
- **delta_dims**: Comma-separated dimension indices for delta (relative) transform.
- **use_delta_transform**: Whether to convert actions to relative to current state for those dimensions.
- **state_key**: Dataset key for state (e.g. `observation.state`) used when applying delta transform.
**Tokenizer & evaluation**
- **action_tokenizer_path**: Path or HuggingFace repo id of the trained tokenizer (e.g. `outputs/wavetoken`).
- **max_chunks_for_reconstruction**: Max number of chunks to use for reconstruction and timing (default `500`) to limit runtime.
### Main parameters
| parameter | default | description |
| -------------------------------- | ---------------------------- | ------------------------------------------------ |
| **action_tokenizer_path** | (required) | Path or Hub id of the trained action tokenizer. |
| **repo_id** | (required) | LeRobot dataset repo id. |
| **action_horizon** | `10` | Future steps per chunk. |
| **encoded_dims** | `0:6` | Dimension ranges to encode (e.g. `0:6,7:14`). |
| **normalization_mode** | `QUANTILES` | Normalization mode for actions. |
| **max_episodes** | all | Max episodes to load. |
| **sample_fraction** | `0.2` | Fraction of chunks sampled per episode. |
| **max_chunks_for_reconstruction**| `500` | Chunks used for reconstruction and timing. |
| **output_dir** | `outputs/action_tokenizer_benchmark` | Directory for results JSON. |
## Metrics
**Reconstruction (lower is better)**
- **reconstruction_mae**: Mean absolute error between original and decoded action chunks.
- **reconstruction_mse**: Mean squared error.
- **reconstruction_rmse**: Root mean squared error.
- **reconstruction_max_abs_error**: Maximum absolute error over all dimensions and samples.
- **per_dimension_mae**: MAE per action dimension (list of length `action_dim`).
**Compression**
- **compression_ratio**: Ratio (action_horizon × action_dim) / mean number of tokens. Higher means more compression.
- **mean_token_length**, **std_token_length**: Mean and standard deviation of token count per chunk.
- **min_token_length**, **max_token_length**: Min and max token count.
- **p50_token_length**, **p99_token_length**: 50th and 99th percentile token counts.
**Timing (seconds per chunk)**
- **mean_encode_time_sec**: Mean time to encode one chunk.
- **mean_decode_time_sec**: Mean time to decode one chunk.
The JSON output also includes **num_chunks_evaluated** and **total_chunks_available** for context.
## How the benchmark works
1. **Load dataset**: LeRobot dataset is loaded for the given `repo_id` and `root`.
2. **Build action chunks**: For each episode (up to `max_episodes`), action chunks are built with the same logic as `lerobot-train-tokenizer`: sliding window of length `action_horizon`, optional delta transform, and per-episode sampling with `sample_fraction`.
3. **Extract and normalize**: Only `encoded_dims` are kept. Normalization is applied using the datasets action stats when available, according to `normalization_mode`.
4. **Encode / decode**: A random sample of chunks (size `max_chunks_for_reconstruction`) is encoded and then decoded with the tokenizer. Encode and decode times are recorded per chunk.
5. **Compute metrics**: Reconstruction metrics are computed between original and decoded chunks; compression and timing stats are aggregated.
6. **Save results**: A JSON file is written to `output_dir` with name `{timestamp}_{repo_id}_action_tokenizer_results.json`, containing the full config and all metrics.
The pipeline (chunking, dimensions, normalization, delta) must match how the tokenizer was trained; otherwise reconstruction error can be large or the tokenizer may raise.
## Caveats
- The tokenizers **action_horizon** and **action_dim** (and optionally DCT settings) are fixed at training time. The benchmark infers dimensions from the dataset and encoded dims; the tokenizer path must correspond to a model trained with the same horizon and encoded dimensions.
- Reconstruction is evaluated in **normalized space** (the same space the tokenizer sees). For interpretation in raw action space, you would need to invert normalization outside this script.
- Only one tokenizer and one dataset are evaluated per run. To compare tokenizers or datasets, run the script multiple times and compare the saved JSON files.
## Example
Quick run with a local tokenizer and a small number of episodes:
```bash
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
--action-tokenizer-path=outputs/wavetoken \
--repo-id=lerobot/pusht \
--action-horizon=10 \
--max-episodes=50 \
--output-dir=outputs/action_tokenizer_benchmark
```
With delta transform and custom encoded dimensions:
```bash
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
--action-tokenizer-path=outputs/wavetoken \
--repo-id=lerobot/pusht \
--action-horizon=10 \
--encoded-dims=0:6,7:14 \
--delta-dims=0,1,2,3,4,5 \
--use-delta-transform \
--normalization-mode=QUANTILES \
--max-chunks-for-reconstruction=500 \
--output-dir=outputs/action_tokenizer_benchmark
```
Results are written to e.g. `outputs/action_tokenizer_benchmark/2026-02-12_14-30-00_lerobot_pusht_action_tokenizer_results.json`.
## Results
Results are stored as JSON in the directory given by `--output-dir` (default: `outputs/action_tokenizer_benchmark`). Each file contains:
- **config**: All script arguments (tokenizer path, repo_id, action_horizon, encoded_dims, normalization_mode, etc.) for reproducibility.
- **metrics**: All reconstruction, compression, and timing metrics described above.
To compare runs, load and diff or aggregate these JSON files with your own scripts or notebooks.
@@ -1,442 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmark action tokenization: reconstruction error, compression ratio, and timing.
Loads action chunks from a LeRobot dataset, encodes/decodes them with a trained action
tokenizer, and reports:
- Reconstruction: MAE, MSE, RMSE, max absolute error, per-dimension MAE
- Jerk: mean absolute jerk (original and reconstructed), jerk reconstruction MAE
- Compression: ratio (input size / mean tokens), token length stats
- Timing: mean encode/decode time per chunk
Results are saved to outputs/action_tokenizer_benchmark/<timestamp>_results.json.
Example:
```bash
python benchmarks/tokens/run_action_tokenizer_benchmark.py \
--action-tokenizer-path=outputs/wavetoken \
--repo-id=lerobot/pusht \
--action-horizon=10 \
--max-episodes=50 \
--output-dir=outputs/action_tokenizer_benchmark
```
"""
import argparse
import json
import time
from pathlib import Path
import numpy as np
from lerobot.configs.types import NormalizationMode
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, OBS_STATE
# Optional: use same helpers as train script if we want to avoid duplication
from lerobot.scripts.lerobot_train_tokenizer import (
apply_normalization,
process_episode,
)
def load_action_chunks(
repo_id: str,
root: str | None,
action_horizon: int,
max_episodes: int | None,
sample_fraction: float,
encoded_dims: str,
delta_dims: str | None,
use_delta_transform: bool,
state_key: str,
normalization_mode: NormalizationMode,
):
"""Load and normalize action chunks from a LeRobot dataset (same pipeline as training)."""
dataset = LeRobotDataset(repo_id=repo_id, root=root)
num_episodes = dataset.num_episodes
if max_episodes is not None:
num_episodes = min(max_episodes, num_episodes)
# Parse encoded dims
encoded_dim_ranges = []
for range_str in encoded_dims.split(","):
start, end = map(int, range_str.strip().split(":"))
encoded_dim_ranges.append((start, end))
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
delta_dim_list = None
if delta_dims is not None and delta_dims.strip():
delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")]
all_chunks = []
for ep_idx in range(num_episodes):
chunks = process_episode(
(
dataset,
ep_idx,
action_horizon,
delta_dim_list,
sample_fraction,
state_key,
use_delta_transform,
)
)
if chunks is not None:
all_chunks.append(chunks)
if not all_chunks:
raise ValueError("No action chunks collected. Check action_horizon and dataset.")
all_chunks = np.concatenate(all_chunks, axis=0)
# Extract encoded dimensions only
encoded_chunks = []
for start, end in encoded_dim_ranges:
encoded_chunks.append(all_chunks[:, :, start:end])
encoded_chunks = np.concatenate(encoded_chunks, axis=-1)
# Normalize
norm_stats = dataset.meta.stats
if norm_stats is not None and ACTION in norm_stats:
action_stats = norm_stats[ACTION]
encoded_dim_indices = []
for start, end in encoded_dim_ranges:
encoded_dim_indices.extend(range(start, end))
encoded_dim_indices = np.array(encoded_dim_indices)
encoded_stats = {}
for stat_name, stat_values in action_stats.items():
if isinstance(stat_values, (list, np.ndarray)):
stat_array = np.array(stat_values)
if len(stat_array) > max(encoded_dim_indices):
encoded_stats[stat_name] = stat_array[encoded_dim_indices]
if encoded_stats:
try:
encoded_chunks = apply_normalization(
encoded_chunks, encoded_stats, normalization_mode, eps=1e-8
)
except ValueError:
pass
return encoded_chunks, total_encoded_dims, action_horizon, dataset.repo_id
def compute_reconstruction_metrics(original: np.ndarray, reconstructed: np.ndarray):
"""Compute reconstruction error metrics (original and reconstructed same shape [N, T, D])."""
diff = reconstructed - original
mae = float(np.mean(np.abs(diff)))
mse = float(np.mean(diff**2))
rmse = float(np.sqrt(mse))
max_abs_err = float(np.max(np.abs(diff)))
# Per-dimension MAE (over N and T)
per_dim_mae = np.mean(np.abs(diff), axis=(0, 1))
per_dim_mae = per_dim_mae.tolist()
return {
"reconstruction_mae": mae,
"reconstruction_mse": mse,
"reconstruction_rmse": rmse,
"reconstruction_max_abs_error": max_abs_err,
"per_dimension_mae": per_dim_mae,
}
def compute_jerk_metrics(original: np.ndarray, reconstructed: np.ndarray) -> dict:
"""Compute jerk (3rd derivative of action w.r.t. time) metrics.
Args:
original: Action chunks [N, T, D].
reconstructed: Reconstructed action chunks [N, T, D].
Returns:
Dict with mean absolute jerk for original, reconstructed, and jerk reconstruction MAE.
"""
# Jerk = 3rd discrete difference along time axis; need T >= 4
if original.shape[1] < 4:
return {}
jerk_orig = np.diff(original, n=3, axis=1) # (N, T-3, D)
jerk_recon = np.diff(reconstructed, n=3, axis=1)
mae_jerk_orig = float(np.mean(np.abs(jerk_orig)))
mae_jerk_recon = float(np.mean(np.abs(jerk_recon)))
jerk_reconstruction_mae = float(np.mean(np.abs(jerk_recon - jerk_orig)))
return {
"jerk_mae_original": mae_jerk_orig,
"jerk_mae_reconstructed": mae_jerk_recon,
"jerk_reconstruction_mae": jerk_reconstruction_mae,
}
def run_benchmark(
action_chunks: np.ndarray,
action_horizon: int,
action_dim: int,
tokenizer_path: str,
max_chunks_for_reconstruction: int | None = 500,
):
"""Encode/decode action chunks and compute metrics."""
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(tokenizer_path, trust_remote_code=True)
n_chunks = len(action_chunks)
sample_size = n_chunks
if max_chunks_for_reconstruction is not None:
sample_size = min(max_chunks_for_reconstruction, n_chunks)
rng = np.random.RandomState(42)
indices = rng.choice(n_chunks, size=sample_size, replace=False)
sample_chunks = action_chunks[indices]
# Encode
token_lengths = []
encode_times = []
all_tokens = []
for i in range(len(sample_chunks)):
chunk = sample_chunks[i : i + 1]
t0 = time.perf_counter()
tokens = processor(chunk)[0]
encode_times.append(time.perf_counter() - t0)
if isinstance(tokens, list):
token_lengths.append(len(tokens))
all_tokens.append(tokens)
else:
n = tokens.shape[0] if hasattr(tokens, "shape") else len(tokens)
token_lengths.append(n)
all_tokens.append(tokens.tolist() if hasattr(tokens, "tolist") else list(tokens))
# Decode (processor keeps time_horizon/action_dim from encode)
decoded_list = []
decode_times = []
for i, tok_list in enumerate(all_tokens):
t0 = time.perf_counter()
recon = processor.decode(
[tok_list],
time_horizon=action_horizon,
action_dim=action_dim,
)
decode_times.append(time.perf_counter() - t0)
decoded_list.append(recon)
decoded = np.concatenate(decoded_list, axis=0)
# Reconstruction metrics
metrics = compute_reconstruction_metrics(sample_chunks, decoded)
# Jerk metrics (3rd derivative along time)
jerk_metrics = compute_jerk_metrics(sample_chunks, decoded)
metrics.update(jerk_metrics)
# Compression
token_lengths = np.array(token_lengths)
input_size = action_horizon * action_dim
compression_ratio = input_size / float(np.mean(token_lengths))
metrics["compression_ratio"] = compression_ratio
metrics["mean_token_length"] = float(np.mean(token_lengths))
metrics["std_token_length"] = float(np.std(token_lengths))
metrics["min_token_length"] = int(np.min(token_lengths))
metrics["max_token_length"] = int(np.max(token_lengths))
metrics["p50_token_length"] = float(np.percentile(token_lengths, 50))
metrics["p99_token_length"] = float(np.percentile(token_lengths, 99))
# Timing (seconds per chunk)
metrics["mean_encode_time_sec"] = float(np.mean(encode_times))
metrics["mean_decode_time_sec"] = float(np.mean(decode_times))
metrics["num_chunks_evaluated"] = sample_size
metrics["total_chunks_available"] = n_chunks
return metrics
def main(
action_tokenizer_path: str,
repo_id: str,
root: str | None = None,
action_horizon: int = 10,
max_episodes: int | None = 100,
sample_fraction: float = 0.2,
encoded_dims: str = "0:6",
delta_dims: str | None = None,
use_delta_transform: bool = False,
state_key: str = OBS_STATE,
normalization_mode: str = "QUANTILES",
max_chunks_for_reconstruction: int | None = 500,
output_dir: str | None = None,
):
if output_dir is None:
output_dir = "outputs/action_tokenizer_benchmark"
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
try:
norm_mode = NormalizationMode(normalization_mode)
except ValueError:
norm_mode = NormalizationMode.QUANTILES
print("Loading action chunks...")
encoded_chunks, action_dim, horizon, _ = load_action_chunks(
repo_id=repo_id,
root=root,
action_horizon=action_horizon,
max_episodes=max_episodes,
sample_fraction=sample_fraction,
encoded_dims=encoded_dims,
delta_dims=delta_dims,
use_delta_transform=use_delta_transform,
state_key=state_key,
normalization_mode=norm_mode,
)
print(f"Loaded {len(encoded_chunks)} chunks, shape {encoded_chunks.shape} (H={horizon}, D={action_dim})")
print("Running tokenizer benchmark...")
metrics = run_benchmark(
action_chunks=encoded_chunks,
action_horizon=horizon,
action_dim=action_dim,
tokenizer_path=action_tokenizer_path,
max_chunks_for_reconstruction=max_chunks_for_reconstruction,
)
# Attach config for reproducibility
results = {
"config": {
"action_tokenizer_path": action_tokenizer_path,
"repo_id": repo_id,
"action_horizon": action_horizon,
"max_episodes": max_episodes,
"sample_fraction": sample_fraction,
"encoded_dims": encoded_dims,
"delta_dims": delta_dims,
"use_delta_transform": use_delta_transform,
"state_key": state_key,
"normalization_mode": normalization_mode,
},
"metrics": metrics,
}
timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
safe_repo = repo_id.replace("/", "_")
out_file = output_path / f"{timestamp}_{safe_repo}_action_tokenizer_results.json"
with open(out_file, "w") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {out_file}")
print("Metrics:")
for k, v in metrics.items():
if isinstance(v, list):
print(f" {k}: (length {len(v)})")
else:
print(f" {k}: {v}")
return results
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark action tokenization (reconstruction error, compression, timing)."
)
parser.add_argument(
"--action-tokenizer-path",
type=str,
required=True,
help="Path or HuggingFace repo id of the trained action tokenizer (e.g. outputs/wavetoken).",
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="LeRobot dataset repo id (e.g. lerobot/pusht).",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="Root directory for LeRobot datasets.",
)
parser.add_argument(
"--action-horizon",
type=int,
default=10,
help="Number of future steps per action chunk.",
)
parser.add_argument(
"--max-episodes",
type=int,
default=None,
help="Max episodes to use (default: all).",
)
parser.add_argument(
"--sample-fraction",
type=float,
default=0.2,
help="Fraction of chunks to sample per episode.",
)
parser.add_argument(
"--encoded-dims",
type=str,
default="0:6",
help="Dimension ranges to encode (e.g. 0:6,7:14).",
)
parser.add_argument(
"--delta-dims",
type=str,
default=None,
help="Comma-separated dimensions for delta transform.",
)
parser.add_argument(
"--use-delta-transform",
action="store_true",
help="Apply delta (relative) transform to specified dimensions.",
)
parser.add_argument(
"--state-key",
type=str,
default=OBS_STATE,
help="Dataset key for state (for delta transform).",
)
parser.add_argument(
"--normalization-mode",
type=str,
default="QUANTILES",
choices=[m.value for m in NormalizationMode],
help="Normalization mode for actions.",
)
parser.add_argument(
"--max-chunks-for-reconstruction",
type=int,
default=500,
help="Max chunks to use for reconstruction metrics (default: 500).",
)
parser.add_argument(
"--output-dir",
type=str,
default="outputs/action_tokenizer_benchmark",
help="Directory to save results JSON (default: outputs/action_tokenizer_benchmark).",
)
args = parser.parse_args()
main(
action_tokenizer_path=args.action_tokenizer_path,
repo_id=args.repo_id,
root=args.root,
action_horizon=args.action_horizon,
max_episodes=args.max_episodes,
sample_fraction=args.sample_fraction,
encoded_dims=args.encoded_dims,
delta_dims=args.delta_dims,
use_delta_transform=args.use_delta_transform,
state_key=args.state_key,
normalization_mode=args.normalization_mode,
max_chunks_for_reconstruction=args.max_chunks_for_reconstruction,
output_dir=args.output_dir,
)
+3 -5
View File
@@ -1,15 +1,13 @@
# Installation
This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-).
## Step 1: Install [`miniforge`](https://conda-forge.org/download/)
## Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Step 2: Environment Setup
## Environment Setup
Create a virtual environment with Python 3.10, using conda:
@@ -40,7 +38,7 @@ conda install ffmpeg -c conda-forge
>
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
## Step 3: Install LeRobot 🤗
## Install LeRobot 🤗
### From Source
+3 -3
View File
@@ -360,9 +360,9 @@ ignore_errors = false
module = "lerobot.cameras.*"
ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.motors.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.robots.*"
+1 -1
View File
@@ -13,5 +13,5 @@
# limitations under the License.
from .camera import Camera
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from .configs import CameraConfig, ColorMode, Cv2Rotation
from .utils import make_cameras_from_configs
-23
View File
@@ -25,10 +25,6 @@ class ColorMode(str, Enum):
RGB = "rgb"
BGR = "bgr"
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.")
class Cv2Rotation(int, Enum):
NO_ROTATION = 0
@@ -36,25 +32,6 @@ class Cv2Rotation(int, Enum):
ROTATE_180 = 180
ROTATE_270 = -90
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.")
# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html
class Cv2Backends(int, Enum):
ANY = 0
V4L2 = 200
DSHOW = 700
PVAPI = 800
ANDROID = 1000
AVFOUNDATION = 1200
MSMF = 1400
@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.")
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
+14 -9
View File
@@ -32,11 +32,10 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_rotation
from ..utils import get_cv2_backend, get_cv2_rotation
from .configuration_opencv import ColorMode, OpenCVCameraConfig
# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance,
@@ -118,7 +117,7 @@ class OpenCVCamera(Camera):
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
self.backend: int = config.backend
self.backend: int = get_cv2_backend()
if self.height and self.width:
self.capture_width, self.capture_height = self.width, self.height
@@ -133,7 +132,6 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the OpenCV camera specified in the configuration.
@@ -150,6 +148,8 @@ class OpenCVCamera(Camera):
ConnectionError: If the specified camera index/path is not found or fails to open.
RuntimeError: If the camera opens but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
# Use 1 thread for OpenCV operations to avoid potential conflicts or
# blocking in multi-threaded applications, especially during data collection.
@@ -178,7 +178,6 @@ class OpenCVCamera(Camera):
logger.info(f"{self} connected.")
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
@@ -198,6 +197,8 @@ class OpenCVCamera(Camera):
to the requested value.
DeviceNotConnectedError: If the camera is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
@@ -347,7 +348,6 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -374,6 +374,9 @@ class OpenCVCamera(Camera):
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -487,7 +490,6 @@ class OpenCVCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -510,6 +512,8 @@ class OpenCVCamera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -529,7 +533,6 @@ class OpenCVCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
@@ -545,6 +548,8 @@ class OpenCVCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -15,9 +15,9 @@
from dataclasses import dataclass
from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from ..configs import CameraConfig, ColorMode, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"]
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
@CameraConfig.register_subclass("opencv")
@@ -50,7 +50,6 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY.
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
@@ -63,12 +62,22 @@ class OpenCVCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
backend: Cv2Backends = Cv2Backends.ANY
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
self.backend = Cv2Backends(self.backend)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(
@@ -74,4 +74,7 @@ class Reachy2CameraConfig(CameraConfig):
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
)
self.color_mode = ColorMode(self.color_mode)
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
@@ -32,7 +32,6 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.import_utils import _reachy2_sdk_available
if TYPE_CHECKING or _reachy2_sdk_available:
@@ -124,7 +123,6 @@ class Reachy2Camera(Camera):
"""
raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -138,6 +136,9 @@ class Reachy2Camera(Camera):
"""
start_time = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -183,7 +184,6 @@ class Reachy2Camera(Camera):
return frame
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Same as read()
@@ -197,10 +197,11 @@ class Reachy2Camera(Camera):
TimeoutError: If no frame becomes available within the specified timeout.
RuntimeError: If an unexpected error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
return self.read()
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
@@ -218,6 +219,8 @@ class Reachy2Camera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.latest_frame is None or self.latest_timestamp is None:
raise RuntimeError(f"{self} has not captured any frames yet.")
@@ -230,7 +233,6 @@ class Reachy2Camera(Camera):
return self.latest_frame
@check_if_not_connected
def disconnect(self) -> None:
"""
Stops the background read thread (if running).
@@ -238,6 +240,8 @@ class Reachy2Camera(Camera):
Raises:
DeviceNotConnectedError: If the camera is already disconnected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} not connected.")
if self.cam_manager is not None:
self.cam_manager.disconnect()
@@ -30,8 +30,7 @@ try:
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -153,7 +152,6 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""
Connects to the RealSense camera specified in the configuration.
@@ -171,6 +169,8 @@ class RealSenseCamera(Camera):
ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
RuntimeError: If the pipeline starts but fails to apply requested settings.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
self.rs_pipeline = rs.pipeline()
rs_config = rs.config()
@@ -290,7 +290,6 @@ class RealSenseCamera(Camera):
if self.use_depth:
rs_config.enable_stream(rs.stream.depth)
@check_if_not_connected
def _configure_capture_settings(self) -> None:
"""Sets fps, width, and height from device stream if not already configured.
@@ -300,6 +299,8 @@ class RealSenseCamera(Camera):
Raises:
DeviceNotConnectedError: If device is not connected.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
@@ -319,7 +320,6 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
@check_if_not_connected
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -345,6 +345,9 @@ class RealSenseCamera(Camera):
f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -371,7 +374,6 @@ class RealSenseCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]:
"""
Reads a single frame (color) synchronously from the camera.
@@ -401,6 +403,9 @@ class RealSenseCamera(Camera):
f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -529,7 +534,6 @@ class RealSenseCamera(Camera):
self.new_frame_event.clear()
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame data (color) asynchronously.
@@ -552,6 +556,8 @@ class RealSenseCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread died unexpectedly or another error occurs.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -572,7 +578,6 @@ class RealSenseCamera(Camera):
return frame
# NOTE(Steven): Missing implementation for depth for now
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent (color) frame captured immediately (Peeking).
@@ -588,6 +593,8 @@ class RealSenseCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -60,8 +60,20 @@ class RealSenseCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
self.rotation = Cv2Rotation(self.rotation)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.rotation not in (
Cv2Rotation.NO_ROTATION,
Cv2Rotation.ROTATE_90,
Cv2Rotation.ROTATE_180,
Cv2Rotation.ROTATE_270,
):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
values = (self.fps, self.width, self.height)
if any(v is not None for v in values) and any(v is None for v in values):
+12
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from typing import cast
from lerobot.utils.import_utils import make_device_from_device_class
@@ -67,3 +68,14 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
else:
return None
def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return int(cv2.CAP_ANY)
+10 -6
View File
@@ -34,8 +34,7 @@ import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
@@ -105,7 +104,6 @@ class ZMQCamera(Camera):
"""Checks if the ZMQ socket is initialized and connected."""
return self._connected and self.context is not None and self.socket is not None
@check_if_already_connected
def connect(self, warmup: bool = True) -> None:
"""Connect to ZMQ camera server.
@@ -113,6 +111,8 @@ class ZMQCamera(Camera):
warmup (bool): If True, waits for the camera to provide at least one
valid frame before returning. Defaults to True.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
logger.info(f"Connecting to {self}...")
@@ -211,7 +211,6 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Reads a single frame synchronously from the camera.
@@ -229,6 +228,9 @@ class ZMQCamera(Camera):
f"{self} read() color_mode parameter is deprecated and will be removed in future versions."
)
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -299,7 +301,6 @@ class ZMQCamera(Camera):
self.latest_timestamp = None
self.new_frame_event.clear()
@check_if_not_connected
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
"""
Reads the latest available frame asynchronously.
@@ -316,6 +317,8 @@ class ZMQCamera(Camera):
TimeoutError: If no frame data becomes available within the specified timeout.
RuntimeError: If the background thread is not running.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
@@ -332,7 +335,6 @@ class ZMQCamera(Camera):
return frame
@check_if_not_connected
def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]:
"""Return the most recent frame captured immediately (Peeking).
@@ -348,6 +350,8 @@ class ZMQCamera(Camera):
DeviceNotConnectedError: If the camera is not connected.
RuntimeError: If the camera is connected but has not captured any frames yet.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive():
raise RuntimeError(f"{self} read thread is not running.")
+4 -1
View File
@@ -32,7 +32,10 @@ class ZMQCameraConfig(CameraConfig):
warmup_s: int = 1
def __post_init__(self) -> None:
self.color_mode = ColorMode(self.color_mode)
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.timeout_ms <= 0:
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")
+9 -10
View File
@@ -216,17 +216,16 @@ class ImageTransformsConfig:
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "SharpnessJitter":
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
transform_cls = getattr(v2, cfg.type, None)
if isinstance(transform_cls, type) and issubclass(transform_cls, Transform):
return transform_cls(**cfg.kwargs)
raise ValueError(
f"Transform '{cfg.type}' is not valid. It must be a class in "
f"torchvision.transforms.v2 or 'SharpnessJitter'."
)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")
class ImageTransforms(Transform):
+2 -7
View File
@@ -112,7 +112,6 @@ class LiberoEnv(gym.Env):
visualization_height: int = 480,
init_states: bool = True,
episode_index: int = 0,
n_envs: int = 1,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
@@ -146,9 +145,7 @@ class LiberoEnv(gym.Env):
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
@@ -298,8 +295,7 @@ class LiberoEnv(gym.Env):
self._env.seed(seed)
raw_obs = self._env.reset()
if self.init_states and self._init_states is not None:
raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
self.init_state_id += self._reset_stride # Change init_state_id when reset
raw_obs = self._env.set_init_state(self._init_states[self._init_state_id])
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles.
@@ -377,7 +373,6 @@ def _make_env_fns(
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
n_envs=n_envs,
control_mode=control_mode,
**local_kwargs,
)
+4 -6
View File
@@ -221,7 +221,7 @@ class RangeFinderGUI:
self.bus = bus
self.groups = groups if groups is not None else {"all": list(bus.motors)}
self.group_names = list(self.groups)
self.group_names = list(groups)
self.current_group = self.group_names[0]
if not bus.is_connected:
@@ -230,20 +230,18 @@ class RangeFinderGUI:
self.calibration = bus.read_calibration()
self.res_table = bus.model_resolution_table
self.present_cache = {
m: bus.read("Present_Position", m, normalize=False)
for motors in self.groups.values()
for m in motors
m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors
}
pygame.init()
self.font = pygame.font.Font(None, FONT_SIZE)
label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms)
label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms)
self.label_pad = label_pad
width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10
self.controls_bottom = 10 + SAVE_H
self.base_y = self.controls_bottom + TOP_GAP
height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40
height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40
self.screen = pygame.display.set_mode((width, height))
pygame.display.set_caption("Motors range finder")
+15 -41
View File
@@ -23,7 +23,6 @@ from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, TypedDict
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _can_available
if TYPE_CHECKING or _can_available:
@@ -37,6 +36,7 @@ else:
import numpy as np
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import enter_pressed, move_cursor_up
@@ -155,7 +155,6 @@ class DamiaoMotorsBus(MotorsBusBase):
"""Check if the CAN bus is connected."""
return self._is_connected and self.canbus is not None
@check_if_already_connected
def connect(self, handshake: bool = True) -> None:
"""
Open the CAN bus and initialize communication.
@@ -163,6 +162,10 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
handshake: If True, ping all motors to verify they're present
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected."
)
try:
# Auto-detect interface type based on port name
@@ -208,9 +211,6 @@ class DamiaoMotorsBus(MotorsBusBase):
logger.info("Starting handshake with motors...")
# Drain any pending messages
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
while self.canbus.recv(timeout=0.01):
pass
@@ -246,7 +246,6 @@ class DamiaoMotorsBus(MotorsBusBase):
)
logger.info("Handshake successful. All motors ready.")
@check_if_not_connected
def disconnect(self, disable_torque: bool = True) -> None:
"""
Close the CAN bus connection.
@@ -254,6 +253,8 @@ class DamiaoMotorsBus(MotorsBusBase):
Args:
disable_torque: If True, disable torque on all motors before disconnecting
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.")
if disable_torque:
try:
@@ -282,10 +283,6 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor)
data = [0xFF] * 7 + [command_byte]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
if msg := self._recv_motor_response(expected_recv_id=recv_id):
self._process_response(motor_name, msg)
@@ -344,10 +341,6 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id = self._get_motor_recv_id(motor)
data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0]
msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd)
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
self.canbus.send(msg)
return self._recv_motor_response(expected_recv_id=recv_id)
@@ -363,10 +356,6 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
CAN message if received, None otherwise
"""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
start_time = time.time()
messages_seen = []
@@ -405,13 +394,10 @@ class DamiaoMotorsBus(MotorsBusBase):
Returns:
Dictionary mapping recv_id to CAN message
"""
responses: dict[int, can.Message] = {}
responses = {}
expected_set = set(expected_recv_ids)
start_time = time.time()
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
try:
while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout:
# 100us poll timeout
@@ -475,9 +461,6 @@ class DamiaoMotorsBus(MotorsBusBase):
motor_name = self._get_motor_name(motor)
motor_type = self._motor_types[motor_name]
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque)
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd)
self.canbus.send(msg)
@@ -505,9 +488,6 @@ class DamiaoMotorsBus(MotorsBusBase):
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Step 1: Send all MIT control commands
for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items():
motor_id = self._get_motor_id(motor)
@@ -582,9 +562,10 @@ class DamiaoMotorsBus(MotorsBusBase):
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
@check_if_not_connected
def read(self, data_name: str, motor: str) -> Value:
"""Read a value from a single motor. Positions are always in degrees."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Refresh motor to get latest state
msg = self._refresh_motor(motor)
@@ -614,7 +595,6 @@ class DamiaoMotorsBus(MotorsBusBase):
raise ValueError(f"Unknown data_name: {data_name}")
return mapping[data_name]
@check_if_not_connected
def write(
self,
data_name: str,
@@ -625,6 +605,8 @@ class DamiaoMotorsBus(MotorsBusBase):
Write a value to a single motor. Positions are always in degrees.
Can write 'Goal_Position', 'Kp', or 'Kd'.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if data_name in ("Kp", "Kd"):
self._gains[motor][data_name.lower()] = float(value)
@@ -674,10 +656,6 @@ class DamiaoMotorsBus(MotorsBusBase):
def _batch_refresh(self, motors: list[str]) -> None:
"""Internal helper to refresh a list of motors and update cache."""
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
# Send refresh commands
for motor in motors:
motor_id = self._get_motor_id(motor)
@@ -700,12 +678,10 @@ class DamiaoMotorsBus(MotorsBusBase):
else:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
@check_if_not_connected
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""
Write values to multiple motors simultaneously. Positions are always in degrees.
"""
if data_name in ("Kp", "Kd"):
key = data_name.lower()
for motor, val in values.items():
@@ -714,8 +690,6 @@ class DamiaoMotorsBus(MotorsBusBase):
elif data_name == "Goal_Position":
# Step 1: Send all MIT control commands
recv_id_to_motor: dict[int, str] = {}
if self.canbus is None:
raise RuntimeError("CAN bus is not initialized.")
for motor, value_degrees in values.items():
motor_id = self._get_motor_id(motor)
motor_name = self._get_motor_name(motor)
@@ -758,9 +732,9 @@ class DamiaoMotorsBus(MotorsBusBase):
def record_ranges_of_motion(
self,
motors: str | list[str] | None = None,
motors: NameOrID | list[NameOrID] | None = None,
display_values: bool = True,
) -> tuple[dict[str, Value], dict[str, Value]]:
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""
Interactively record the min/max values of each motor in degrees.
+8 -8
View File
@@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus):
for motor, m in self.motors.items():
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=int(drive_modes[motor]),
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
drive_mode=drive_modes[motor],
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
@@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
if cache:
self.calibration = calibration_dict
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
@@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable")
self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
@@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus):
On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset
"""
half_turn_homings: dict[NameOrID, Value] = {}
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return None
return
return {id_: data[0] for id_, data in data_list.items()}
+9 -9
View File
@@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus):
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
self.port_handler, scs.PortHandler
)
self.packet_handler = scs.PacketHandler(protocol_version)
@@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus):
calibration[motor] = MotorCalibration(
id=m.id,
drive_mode=0,
homing_offset=int(offsets[motor]),
range_min=int(mins[motor]),
range_max=int(maxes[motor]),
homing_offset=offsets[motor],
range_min=mins[motor],
range_max=maxes[motor],
)
return calibration
@@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus):
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
"""
half_turn_homings: dict[NameOrID, Value] = {}
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
@@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus):
return half_turn_homings
def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", motor, 0, num_retry=num_retry)
@@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus):
addr, length = get_address(self.model_ctrl_table, model, "Lock")
self._write(addr, length, motor, 0, num_retry=num_retry)
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for motor in self._get_motors_list(motors):
self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", motor, 1, num_retry=num_retry)
@@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus):
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list: dict[int, int] = {}
data_list = {}
status_length = 6
@@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus):
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return None
return
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
+90 -93
View File
@@ -23,7 +23,6 @@ from __future__ import annotations
import abc
import logging
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -94,7 +93,7 @@ class MotorsBusBase(abc.ABC):
pass
@abc.abstractmethod
def sync_write(self, data_name: str, values: dict[str, Value]) -> None:
def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None:
"""Write values to multiple motors."""
pass
@@ -180,16 +179,15 @@ class Motor:
class PortHandler(Protocol):
is_open: bool
baudrate: int
packet_start_time: float
packet_timeout: float
tx_time_per_byte: float
is_using: bool
port_name: str
ser: serial.Serial
def __init__(self, port_name: str) -> None: ...
def __init__(self, port_name):
self.is_open: bool
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool
self.port_name: str
self.ser: serial.Serial
def openPort(self): ...
def closePort(self): ...
@@ -242,22 +240,19 @@ class PacketHandler(Protocol):
def regWriteTxRx(self, port, id, address, length, data): ...
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
def broadcastPing(self, port): ...
class GroupSyncRead(Protocol):
port: str
ph: PortHandler
start_address: int
data_length: int
last_result: bool
is_param_changed: bool
param: list
data_dict: dict
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.last_result: bool
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id): ...
def removeParam(self, id): ...
@@ -270,17 +265,15 @@ class GroupSyncRead(Protocol):
class GroupSyncWrite(Protocol):
port: str
ph: PortHandler
start_address: int
data_length: int
is_param_changed: bool
param: list
data_dict: dict
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def __init__(
self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int
) -> None: ...
def makeParam(self): ...
def addParam(self, id, data): ...
def removeParam(self, id): ...
@@ -407,7 +400,7 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motor_model(self, motor: NameOrID) -> str:
def _get_motor_model(self, motor: NameOrID) -> int:
if isinstance(motor, str):
return self.motors[motor].model
elif isinstance(motor, int):
@@ -415,19 +408,17 @@ class SerialMotorsBus(MotorsBusBase):
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]:
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
if motors is None:
return list(self.motors)
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, int):
return [self._id_to_name(motors)]
elif isinstance(motors, Sequence):
return [m if isinstance(m, str) else self._id_to_name(m) for m in motors]
elif isinstance(motors, list):
return motors.copy()
else:
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]:
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
@@ -649,19 +640,18 @@ class SerialMotorsBus(MotorsBusBase):
pass
@abc.abstractmethod
def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None:
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
"""Enable torque on selected motors.
Args:
motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`.
Defaults to `None`.
motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`.
num_retry (int, optional): Number of additional retry attempts on communication failure.
Defaults to 0.
"""
pass
@contextmanager
def torque_disabled(self, motors: str | list[str] | None = None):
def torque_disabled(self, motors: int | str | list[str] | None = None):
"""Context-manager that guarantees torque is re-enabled.
This helper is useful to temporarily disable torque when configuring motors.
@@ -738,19 +728,24 @@ class SerialMotorsBus(MotorsBusBase):
"""
pass
def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None:
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
"""Restore factory calibration for the selected motors.
Homing offset is set to ``0`` and min/max position limits are set to the full usable range.
The in-memory :pyattr:`calibration` is cleared.
Args:
motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default)
motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default)
resets every motor.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
for motor in motor_names:
for motor in motors:
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
self.write("Homing_Offset", motor, 0, normalize=False)
@@ -759,9 +754,7 @@ class SerialMotorsBus(MotorsBusBase):
self.calibration = {}
def set_half_turn_homings(
self, motors: NameOrID | Sequence[NameOrID] | None = None
) -> dict[NameOrID, Value]:
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
"""Centre each motor range around its current position.
The function computes and writes a homing offset such that the present position becomes exactly one
@@ -771,12 +764,17 @@ class SerialMotorsBus(MotorsBusBase):
motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`).
Returns:
dict[str, Value]: Mapping *motor name → written homing offset*.
dict[NameOrID, Value]: Mapping *motor → written homing offset*.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
self.reset_calibration(motor_names)
actual_positions = self.sync_read("Present_Position", motor_names, normalize=False)
self.reset_calibration(motors)
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset)
@@ -788,8 +786,8 @@ class SerialMotorsBus(MotorsBusBase):
pass
def record_ranges_of_motion(
self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[str, Value], dict[str, Value]]:
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""Interactively record the min/max encoder values of each motor.
Move the joints by hand (with torque disabled) while the method streams live positions. Press
@@ -801,25 +799,30 @@ class SerialMotorsBus(MotorsBusBase):
display_values (bool, optional): When `True` (default) a live table is printed to the console.
Returns:
tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the
tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the
extreme values observed for each motor.
"""
motor_names = self._get_motors_list(motors)
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
start_positions = self.sync_read("Present_Position", motor_names, normalize=False)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
user_pressed_enter = False
while not user_pressed_enter:
positions = self.sync_read("Present_Position", motor_names, normalize=False)
positions = self.sync_read("Present_Position", motors, normalize=False)
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for motor in motor_names:
for motor in motors:
print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}")
if enter_pressed():
@@ -827,9 +830,9 @@ class SerialMotorsBus(MotorsBusBase):
if display_values and not user_pressed_enter:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motor_names) + 3)
move_cursor_up(len(motors) + 3)
same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]]
same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]]
if same_min_max:
raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}")
@@ -952,12 +955,12 @@ class SerialMotorsBus(MotorsBusBase):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else:
return None
return
if self._is_error(error):
if raise_on_error:
raise RuntimeError(self.packet_handler.getRxPacketError(error))
else:
return None
return
return model_number
@@ -1004,13 +1007,12 @@ class SerialMotorsBus(MotorsBusBase):
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
decoded = self._decode_sign(data_name, {id_: value})
id_value = self._decode_sign(data_name, {id_: value})
if normalize and data_name in self.normalized_data:
normalized = self._normalize(decoded)
return normalized[id_]
id_value = self._normalize(id_value)
return decoded[id_]
return id_value[id_]
def _read(
self,
@@ -1021,7 +1023,7 @@ class SerialMotorsBus(MotorsBusBase):
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int, int]:
) -> tuple[int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif length == 2:
@@ -1071,14 +1073,13 @@ class SerialMotorsBus(MotorsBusBase):
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_value = int(value)
if normalize and data_name in self.normalized_data:
int_value = self._unnormalize({id_: value})[id_]
value = self._unnormalize({id_: value})[id_]
int_value = self._encode_sign(data_name, {id_: int_value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_]
err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries."
self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self,
@@ -1112,7 +1113,7 @@ class SerialMotorsBus(MotorsBusBase):
def sync_read(
self,
data_name: str,
motors: NameOrID | Sequence[NameOrID] | None = None,
motors: str | list[str] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
@@ -1121,7 +1122,7 @@ class SerialMotorsBus(MotorsBusBase):
Args:
data_name (str): Register name.
motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor.
motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor.
normalize (bool, optional): Normalisation flag. Defaults to `True`.
num_retry (int, optional): Retry attempts. Defaults to `0`.
@@ -1142,17 +1143,16 @@ class SerialMotorsBus(MotorsBusBase):
addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
raw_ids_values, _ = self._sync_read(
ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
decoded = self._decode_sign(data_name, raw_ids_values)
ids_values = self._decode_sign(data_name, ids_values)
if normalize and data_name in self.normalized_data:
normalized = self._normalize(decoded)
return {self._id_to_name(id_): value for id_, value in normalized.items()}
ids_values = self._normalize(ids_values)
return {self._id_to_name(id_): value for id_, value in decoded.items()}
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
def _sync_read(
self,
@@ -1224,24 +1224,21 @@ class SerialMotorsBus(MotorsBusBase):
num_retry (int, optional): Retry attempts. Defaults to `0`.
"""
raw_ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in raw_ids_values]
ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values]
if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()}
if normalize and data_name in self.normalized_data:
int_ids_values = self._unnormalize(raw_ids_values)
ids_values = self._unnormalize(ids_values)
int_ids_values = self._encode_sign(data_name, int_ids_values)
ids_values = self._encode_sign(data_name, ids_values)
err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries."
self._sync_write(
addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _sync_write(
self,
@@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone().mean().item()
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item()
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
loss_dict["losses_after_rm_padding"] = losses.clone()
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
+6 -4
View File
@@ -17,7 +17,7 @@ from dataclasses import dataclass
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep):
# copy over non-STATE features
for ft, feats in features.items():
if ft != FeatureType.STATE:
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
@@ -100,11 +100,13 @@ class LiberoProcessorStep(ObservationProcessorStep):
# add our new flattened state
state_feats[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
key=OBS_STATE,
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[FeatureType.STATE] = state_feats
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
@@ -19,7 +19,6 @@ from functools import cached_property
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
@@ -113,7 +112,6 @@ class BiOpenArmFollower(Robot):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -135,7 +133,6 @@ class BiOpenArmFollower(Robot):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
@@ -149,7 +146,6 @@ class BiOpenArmFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(
self,
action: RobotAction,
@@ -174,7 +170,6 @@ class BiOpenArmFollower(Robot):
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -19,7 +19,6 @@ from functools import cached_property
from lerobot.processor import RobotAction, RobotObservation
from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from .config_bi_so_follower import BiSOFollowerConfig
@@ -97,7 +96,6 @@ class BiSOFollower(Robot):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -118,7 +116,6 @@ class BiSOFollower(Robot):
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
@@ -132,7 +129,6 @@ class BiSOFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(self, action: RobotAction) -> RobotAction:
# Remove "left_" prefix
left_action = {
@@ -152,7 +148,6 @@ class BiSOFollower(Robot):
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -23,7 +23,7 @@ from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.processor import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from ..utils import ensure_safe_goal_position
@@ -119,7 +119,6 @@ class OpenArmFollower(Robot):
"""Check if robot is connected."""
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the robot and optionally calibrate.
@@ -127,6 +126,8 @@ class OpenArmFollower(Robot):
We assume that at connection time, the arms are in a safe rest position,
and torque can be safely disabled to run calibration if needed.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to CAN bus
logger.info(f"Connecting arm on {self.config.port}...")
@@ -218,7 +219,6 @@ class OpenArmFollower(Robot):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_observation(self) -> RobotObservation:
"""
Get current observation from robot including position, velocity, and torque.
@@ -228,6 +228,9 @@ class OpenArmFollower(Robot):
"""
start = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict: dict[str, Any] = {}
states = self.bus.sync_read_all_states()
@@ -250,7 +253,6 @@ class OpenArmFollower(Robot):
return obs_dict
@check_if_not_connected
def send_action(
self,
action: RobotAction,
@@ -270,6 +272,8 @@ class OpenArmFollower(Robot):
Returns:
The action actually sent (potentially clipped)
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
@@ -329,9 +333,10 @@ class OpenArmFollower(Robot):
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
@check_if_not_connected
def disconnect(self):
"""Disconnect from robot."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Disconnect CAN bus
self.bus.disconnect(self.config.disable_torque_on_disconnect)
+20 -6
View File
@@ -184,6 +184,9 @@ class DatasetRecordConfig:
vcodec: str = "libsvtav1"
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
# Expert noise injection scale. Noise is added to robot actions but not recorded in dataset.
# This forces recovery behavior for more robust learned policies. 0.0 means no noise. #https://arxiv.org/pdf/1703.09327, https://arxiv.org/abs/2507.09061
noise_scale: float = 0.0
def __post_init__(self):
if self.single_task is None:
@@ -283,6 +286,7 @@ def record_loop(
single_task: str | None = None,
display_data: bool = False,
display_compressed_images: bool = False,
noise_scale: float = 0.0,
):
if dataset is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
@@ -380,18 +384,27 @@ def record_loop(
action_values = act_processed_teleop
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
# Write clean action to dataset (before noise injection)
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)
# Expert noise injection: add noise to motor commands but not to recorded labels
if noise_scale > 0:
import torch
for key in robot_action_to_send:
if isinstance(robot_action_to_send[key], torch.Tensor):
noise = torch.randn_like(robot_action_to_send[key]) * noise_scale
robot_action_to_send[key] = robot_action_to_send[key] + noise
# Send action to robot
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
_sent_action = robot.send_action(robot_action_to_send)
# Write to dataset
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)
if display_data:
log_rerun_data(
observation=obs_processed, action=action_values, compress_images=display_compressed_images
@@ -510,6 +523,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
display_compressed_images=display_compressed_images,
noise_scale=cfg.dataset.noise_scale,
)
# Execute a few seconds without recording to give time to manually reset the environment
+2 -2
View File
@@ -45,7 +45,7 @@ from dataclasses import dataclass, field
import draccus
from lerobot.utils.import_utils import _can_available
from lerobot.utils.import_utils import is_package_available
MOTOR_NAMES = {
0x01: "joint_1",
@@ -336,7 +336,7 @@ def run_speed(cfg: CANSetupConfig):
@draccus.wrap()
def setup_can(cfg: CANSetupConfig):
if not _can_available:
if not is_package_available("can"):
print("Error: python-can not installed. Install with: pip install python-can")
sys.exit(1)
@@ -166,9 +166,9 @@ def apply_normalization(
if q01 is None or q99 is None:
raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats")
denom = np.maximum(q99 - q01, eps)
# No clipping: match training pipeline NormalizerProcessorStep so tokenizer
# is fit on the full range of normalized values (including tails outside [-1, 1]).
return 2.0 * (data - q01) / denom - 1.0
# Clip to quantile range then normalize to [-1, 1]
clipped = np.clip(data, q01, q99)
return 2.0 * (clipped - q01) / denom - 1.0
if mode == NormalizationMode.QUANTILE10:
q10 = stats.get("q10")
@@ -176,8 +176,9 @@ def apply_normalization(
if q10 is None or q90 is None:
raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats")
denom = np.maximum(q90 - q10, eps)
# No clipping: match training pipeline NormalizerProcessorStep.
return 2.0 * (data - q10) / denom - 1.0
# Clip to quantile range then normalize to [-1, 1]
clipped = np.clip(data, q10, q90)
return 2.0 * (clipped - q10) / denom - 1.0
raise ValueError(f"Unsupported normalization mode: {mode}")
@@ -305,7 +306,7 @@ def train_fast_tokenizer(
# download the tokenizer source code (not pretrained weights)
# we'll train a new tokenizer on our own data
base_tokenizer = AutoProcessor.from_pretrained("/fsx/jade_choghari/outputs/libero_tokenizer_wavetoken1", trust_remote_code=True)
base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# convert action_chunks array to list of arrays (expected by .fit())
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
@@ -319,8 +320,6 @@ def train_fast_tokenizer(
vocab_size=vocab_size,
time_horizon=action_chunks.shape[1], # action_horizon
action_dim=action_chunks.shape[2], # encoded dimensions
wavelet="dmey",
level=1,
)
print("✓ Tokenizer training complete!")
@@ -19,7 +19,6 @@ from functools import cached_property
from lerobot.processor import RobotAction
from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_leader import OpenArmLeader
from ..teleoperator import Teleoperator
@@ -89,7 +88,6 @@ class BiOpenArmLeader(Teleoperator):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -111,7 +109,6 @@ class BiOpenArmLeader(Teleoperator):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_action(self) -> RobotAction:
action_dict = {}
@@ -129,7 +126,6 @@ class BiOpenArmLeader(Teleoperator):
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -18,7 +18,7 @@ import logging
from functools import cached_property
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.decorators import check_if_not_connected
from ..so_leader import SOLeader
from ..teleoperator import Teleoperator
@@ -72,7 +72,6 @@ class BiSOLeader(Teleoperator):
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@@ -111,7 +110,6 @@ class BiSOLeader(Teleoperator):
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -21,7 +21,7 @@ from typing import Any
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.processor import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator
from .config_openarm_leader import OpenArmLeaderConfig
@@ -84,7 +84,6 @@ class OpenArmLeader(Teleoperator):
"""Check if teleoperator is connected."""
return self.bus.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
"""
Connect to the teleoperator.
@@ -92,6 +91,8 @@ class OpenArmLeader(Teleoperator):
For manual control, we disable torque after connecting so the
arm can be moved by hand.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
# Connect to CAN bus
logger.info(f"Connecting arm on {self.config.port}...")
@@ -182,7 +183,6 @@ class OpenArmLeader(Teleoperator):
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
)
@check_if_not_connected
def get_action(self) -> RobotAction:
"""
Get current action from the leader arm.
@@ -193,6 +193,8 @@ class OpenArmLeader(Teleoperator):
Reads all motor states (pos/vel/torque) in one CAN refresh cycle.
"""
start = time.perf_counter()
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
action_dict: dict[str, Any] = {}
@@ -212,9 +214,10 @@ class OpenArmLeader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")
@check_if_not_connected
def disconnect(self) -> None:
"""Disconnect from teleoperator."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
# Disconnect CAN bus
# For manual control, ensure torque is disabled before disconnecting
-24
View File
@@ -390,30 +390,6 @@ def test_sharpness_jitter_invalid_range_max_smaller():
SharpnessJitter((2.0, 0.1))
def test_make_transform_from_config_with_v2_resize(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)})
tf = make_transform_from_config(tf_cfg)
assert isinstance(tf, v2.Resize)
output = tf(img_tensor)
assert output.shape[-2:] == (32, 32)
def test_make_transform_from_config_with_v2_identity(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformConfig(type="Identity", kwargs={})
tf = make_transform_from_config(tf_cfg)
assert isinstance(tf, v2.Identity)
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_make_transform_from_config_invalid_type():
tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={})
with pytest.raises(ValueError, match="not valid"):
make_transform_from_config(tf_cfg)
def test_save_all_transforms(img_tensor_factory, tmp_path):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(enable=True)