Compare commits

..

5 Commits

Author SHA1 Message Date
Maxime Ellerbach 2e0deff3ab fixing final upload to hub 2026-06-17 09:42:05 +00:00
Maxime Ellerbach b42d124007 cleanup 2026-06-15 14:50:23 +00:00
Maxime Ellerbach 3ce50c3468 adding a test for the fsdp checkpoint path 2026-06-15 14:36:22 +00:00
Maxime Ellerbach 44fd3c0a0e adding docs for FSDP 2026-06-15 14:15:09 +00:00
Maxime Ellerbach 0483afc743 feat(train): FSDP checkpoint saving 2026-06-15 14:03:17 +00:00
51 changed files with 1497 additions and 1671 deletions
-1
View File
@@ -136,7 +136,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. - **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. - **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot. - **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
- **[LeLab](https://github.com/huggingface/leLab):** A web interface for LeRobot — teleoperate, calibrate, record datasets, replay, and train your SO arm from the browser, no CLI required.
## Citation ## Citation
+8 -8
View File
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
**Compatible teleoperators:** **Compatible teleoperators:**
- `bi_openarm_mini` - Bimanual OpenArm Mini - `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm - `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT] > [!IMPORTANT]
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`. > The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags. > `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
--- ---
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \ --robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \ --robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \ --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \ --teleop.type=openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \ --teleop.port_left=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \ --teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_dataset \ --dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \ --dataset.single_task="Fold the T-shirt properly" \
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \ --robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \ --robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \ --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \ --teleop.type=openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \ --teleop.port_left=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \ --teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \ --dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \ --dataset.single_task="Fold the T-shirt properly" \
+1 -1
View File
@@ -117,7 +117,7 @@ lerobot-rollout \
--strategy.num_episodes=20 \ --strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \ --robot.type=bi_openarm_follower \
--teleop.type=bi_openarm_mini \ --teleop.type=openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \ --dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt" --dataset.single_task="Fold the T-shirt"
``` ```
+46
View File
@@ -113,6 +113,52 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
--policy=act --policy=act
``` ```
## Training Large Models with FSDP
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
An example on how to launch LeRobot training with FSDP across 4 GPUs (1 machine):
```bash
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
--dataset.repo_id=${HF_USER}/my_dataset \
--policy.type=<your_policy> \
--output_dir=outputs/train/my_policy_fsdp
```
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_machines: 1
num_processes: 4
fsdp_config:
fsdp_version: 1
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
fsdp_state_dict_type: FULL_STATE_DICT
```
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
optimizer before `accelerator.prepare()`.
### FSDP checkpoints
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two thigs to look out for:
- With mixed precision, (`bf16`/`fp16`) FSDP keeps an fp32 master copy, so the checkpoint is fp32
(~2× the bf16 size on disk) and is cast back to the policy dtype on load.
- **Optimizer state is not saved under FSDP**, so **resume-from-checkpoint is not supported**.
Saved weights are fully usable for evaluation and fine-tuning.
## Notes ## Notes
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration. - The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
@@ -54,7 +54,6 @@ from typing import Any
import pyarrow as pa import pyarrow as pa
import pyarrow.parquet as pq import pyarrow.parquet as pq
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
from lerobot.datasets.language import ( from lerobot.datasets.language import (
EVENT_ONLY_STYLES, EVENT_ONLY_STYLES,
LANGUAGE_EVENTS, LANGUAGE_EVENTS,
@@ -275,11 +274,12 @@ class LanguageColumnsWriter:
new_table = self._materialize_table( new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
) )
# Re-emit one row group per episode (a bulk pq.write_table would collapse # Atomic replace: write to a sibling tmp path and rename so a crash
# them into one). Write to a sibling tmp path and atomically rename so a # mid-write can't leave a half-written shard that ``pq.read_table``
# crash mid-write can't leave a half-written shard. # would then fail to open. ``Path.replace`` is atomic on POSIX +
# Windows when source and target sit on the same filesystem.
tmp_path = path.with_suffix(path.suffix + ".tmp") tmp_path = path.with_suffix(path.suffix + ".tmp")
write_table_one_row_group_per_episode(new_table, tmp_path) pq.write_table(new_table, tmp_path)
tmp_path.replace(path) tmp_path.replace(path)
def _materialize_table( def _materialize_table(
+2 -3
View File
@@ -442,12 +442,11 @@ class OpenCVCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues. Stops on DeviceNotConnectedError, logs other errors and continues.
""" """
stop_event = self.stop_event if self.stop_event is None:
if stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0 failure_count = 0
while not stop_event.is_set(): while not self.stop_event.is_set():
try: try:
raw_frame = self._read_from_hardware() raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame) processed_frame = self._postprocess_image(raw_frame)
@@ -471,12 +471,11 @@ class RealSenseCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues. Stops on DeviceNotConnectedError, logs other errors and continues.
""" """
stop_event = self.stop_event if self.stop_event is None:
if stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0 failure_count = 0
while not stop_event.is_set(): while not self.stop_event.is_set():
try: try:
frame = self._read_from_hardware() frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame() color_frame_raw = frame.get_color_frame()
+2 -3
View File
@@ -246,12 +246,11 @@ class ZMQCamera(Camera):
""" """
Internal loop run by the background thread for asynchronous reading. Internal loop run by the background thread for asynchronous reading.
""" """
stop_event = self.stop_event if self.stop_event is None:
if stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.") raise RuntimeError(f"{self}: stop_event is not initialized.")
failure_count = 0 failure_count = 0
while not stop_event.is_set(): while not self.stop_event.is_set():
try: try:
frame = self._read_from_hardware() frame = self._read_from_hardware()
capture_time = time.perf_counter() capture_time = time.perf_counter()
+7 -1
View File
@@ -98,6 +98,7 @@ def save_checkpoint(
postprocessor: PolicyProcessorPipeline | None = None, postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None, num_processes: int | None = None,
batch_size: int | None = None, batch_size: int | None = None,
model_state_dict: dict | None = None,
) -> None: ) -> None:
"""This function creates the following directory structure: """This function creates the following directory structure:
@@ -127,9 +128,14 @@ def save_checkpoint(
resume. Defaults to None (not recorded). resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded). resume. Defaults to None (not recorded).
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
cross-rank collective and passes it here so rank 0 can write it directly. It holds
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
""" """
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir) policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
cfg.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir)
if cfg.peft is not None: if cfg.peft is not None:
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the # When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
-9
View File
@@ -32,7 +32,6 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
from .io_utils import ( from .io_utils import (
get_file_size_in_mb, get_file_size_in_mb,
get_parquet_file_size_in_mb, get_parquet_file_size_in_mb,
to_parquet_one_row_group_per_episode,
to_parquet_with_hf_images, to_parquet_with_hf_images,
write_info, write_info,
write_stats, write_stats,
@@ -552,7 +551,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
aggr_root=dst_meta.root, aggr_root=dst_meta.root,
hf_features=hf_features, hf_features=hf_features,
concatenate=concatenate_data, concatenate=concatenate_data,
one_row_group_per_episode=True,
) )
# Record the mapping from source to actual destination # Record the mapping from source to actual destination
@@ -630,7 +628,6 @@ def append_or_create_parquet_file(
aggr_root: Path = None, aggr_root: Path = None,
hf_features: datasets.Features | None = None, hf_features: datasets.Features | None = None,
concatenate: bool = True, concatenate: bool = True,
one_row_group_per_episode: bool = False,
) -> tuple[dict[str, int], tuple[int, int]]: ) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints. """Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -648,8 +645,6 @@ def append_or_create_parquet_file(
aggr_root: Root path for the aggregated dataset. aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing. hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one. concatenate: When False, always rotate to a new file instead of appending to the current one.
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
the episodes-metadata parquet (already one row per episode).
Returns: Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -662,8 +657,6 @@ def append_or_create_parquet_file(
dst_path.parent.mkdir(parents=True, exist_ok=True) dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images: if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features) to_parquet_with_hf_images(df, dst_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(df, dst_path)
else: else:
df.to_parquet(dst_path) df.to_parquet(dst_path)
return idx, (dst_chunk, dst_file) return idx, (dst_chunk, dst_file)
@@ -690,8 +683,6 @@ def append_or_create_parquet_file(
if contains_images: if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features) to_parquet_with_hf_images(final_df, target_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(final_df, target_path)
else: else:
final_df.to_parquet(target_path) final_df.to_parquet(target_path)
+1 -2
View File
@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@@ -710,7 +709,7 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False) obj.root.mkdir(parents=True, exist_ok=False)
features = {**deepcopy(features), **DEFAULT_FEATURES} features = {**features, **DEFAULT_FEATURES}
_validate_feature_names(features) _validate_feature_names(features)
obj.tasks = None obj.tasks = None
-12
View File
@@ -74,8 +74,6 @@ class DatasetReader:
self.episodes = episodes self.episodes = episodes
self._tolerance_s = tolerance_s self._tolerance_s = tolerance_s
self._video_backend = video_backend self._video_backend = video_backend
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms self._image_transforms = image_transforms
self._return_uint8 = return_uint8 self._return_uint8 = return_uint8
@@ -88,16 +86,6 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms
def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations."""
self._image_transforms = None
def try_load(self) -> bool: def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient.""" """Attempt to load from local cache. Returns True if data is sufficient."""
try: try:
+1 -4
View File
@@ -27,7 +27,6 @@ import logging
import shutil import shutil
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from copy import deepcopy
from pathlib import Path from pathlib import Path
import datasets import datasets
@@ -1102,9 +1101,7 @@ def _copy_episodes_metadata_and_stats(
if dst_meta.video_keys and src_dataset.meta.video_keys: if dst_meta.video_keys and src_dataset.meta.video_keys:
for key in dst_meta.video_keys: for key in dst_meta.video_keys:
if key in src_dataset.meta.features: if key in src_dataset.meta.features:
dst_meta.info.features[key]["info"] = deepcopy( dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
src_dataset.meta.info.features[key].get("info", {})
)
write_info(dst_meta.info, dst_meta.root) write_info(dst_meta.info, dst_meta.root)
+9 -38
View File
@@ -20,7 +20,6 @@ import datasets
import numpy as np import numpy as np
import pandas import pandas
import pandas as pd import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pa_ds import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq import pyarrow.parquet as pq
import torch import torch
@@ -271,49 +270,21 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
return items_dict return items_dict
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
"""Write ``table`` with one parquet row group per episode (in episode order).
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
mirroring the recording writer. ``table`` must carry a contiguous
``episode_index`` column.
"""
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
try:
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
finally:
writer.close()
def to_parquet_with_hf_images( def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None: ) -> None:
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode. """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
Images are embedded into the arrow table first (``ParquetWriter.write_table`` Args:
does not embed external image files like ``Dataset.to_parquet`` does). df: DataFrame to write to parquet.
``features`` types image columns as ``Image()`` in the parquet schema. path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
""" """
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds = embed_images(ds) ds.to_parquet(path)
table = ds.with_format("arrow")[:]
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
# No episode boundaries to align row groups to — keep a single write.
pq.write_table(table, str(path))
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
table = pa.Table.from_pandas(df, preserve_index=False)
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
pq.write_table(table, str(path))
def item_to_torch(item: dict) -> dict: def item_to_torch(item: dict) -> dict:
+7 -5
View File
@@ -201,6 +201,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
self._requested_root = Path(root) if root else None self._requested_root = Path(root) if root else None
self.reader = None
self.set_image_transforms(image_transforms)
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
@@ -247,7 +249,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_transforms=image_transforms, image_transforms=image_transforms,
return_uint8=self._return_uint8, return_uint8=self._return_uint8,
) )
self.image_transforms = image_transforms
# Load actual data # Load actual data
if force_cache_sync or not self.reader.try_load(): if force_cache_sync or not self.reader.try_load():
@@ -504,14 +505,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
def set_image_transforms(self, image_transforms: Callable | None) -> None: def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations.""" """Replace the transform applied to visual observations."""
self._ensure_reader().set_image_transforms(image_transforms) if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self.image_transforms = image_transforms self.image_transforms = image_transforms
if self.reader is not None:
self.reader._image_transforms = image_transforms
def clear_image_transforms(self) -> None: def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations.""" """Remove the transform applied to visual observations."""
if self.reader is not None: self.set_image_transforms(None)
self.reader.set_image_transforms(None)
self.image_transforms = None
# ── Hub methods (stay on facade) ────────────────────────────────── # ── Hub methods (stay on facade) ──────────────────────────────────
+3 -5
View File
@@ -70,21 +70,19 @@ def aggregate_pipeline_dataset_features(
initial_features: dict[PipelineFeatureType, dict[str, Any]], initial_features: dict[PipelineFeatureType, dict[str, Any]],
*, *,
use_videos: bool = True, use_videos: bool = True,
exclude_images: bool = False,
patterns: Sequence[str] | None = None, patterns: Sequence[str] | None = None,
) -> dict[str, dict]: ) -> dict[str, dict]:
""" """
Aggregates and filters pipeline features to create a dataset-ready features dictionary. Aggregates and filters pipeline features to create a dataset-ready features dictionary.
This function transforms initial features using the pipeline, categorizes them as action or observations This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `exclude_images` and `patterns`, and finally (image or state), filters them based on `use_videos` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset. formats them for use with a Hugging Face LeRobot Dataset.
Args: Args:
pipeline: The DataProcessorPipeline to apply. pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations. initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image". use_videos: If False, image features are excluded.
exclude_images: If True, image features are dropped entirely from the output.
patterns: A sequence of regex patterns to filter action and state features. patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter. Image features are not affected by this filter.
@@ -122,7 +120,7 @@ def aggregate_pipeline_dataset_features(
) )
# 2. Apply filtering rules. # 2. Apply filtering rules.
if is_image and exclude_images: if is_image and not use_videos:
continue continue
if not is_image and not should_keep(key, compiled_patterns): if not is_image and not should_keep(key, compiled_patterns):
continue continue
+39 -4
View File
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
import packaging import packaging
import safetensors import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
@@ -129,10 +129,43 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if not getattr(cls, "name", None): if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'") raise TypeError(f"Class {cls.__name__} must define 'name'")
def _save_pretrained(self, save_directory: Path) -> None: def save_pretrained(
self,
save_directory: str | Path,
*,
state_dict: dict[str, Tensor] | None = None,
repo_id: str | None = None,
push_to_hub: bool = False,
card_kwargs: dict | None = None,
**push_to_hub_kwargs,
) -> str | None:
"""Save the policy to a directory (and optionally push to the Hub).
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
return sharded tensors, so the caller gathers the full state dict via a cross-rank
collective and passes it here for `_save_pretrained` to write directly.
"""
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
self._save_pretrained(save_directory, state_dict=state_dict)
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return None
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
self.config._save_pretrained(save_directory) self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)) if state_dict is None:
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
return
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
@classmethod @classmethod
def from_pretrained( def from_pretrained(
@@ -270,6 +303,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self, self,
cfg: TrainPipelineConfig, cfg: TrainPipelineConfig,
peft_model=None, peft_model=None,
state_dict: dict[str, Tensor] | None = None,
): ):
api = HfApi() api = HfApi()
repo_id = api.create_repo( repo_id = api.create_repo(
@@ -287,7 +321,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
peft_model.save_pretrained(saved_path) peft_model.save_pretrained(saved_path)
self.config.save_pretrained(saved_path) self.config.save_pretrained(saved_path)
else: else:
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors # Calls _save_pretrained and stores model tensors
self.save_pretrained(saved_path, state_dict=state_dict)
card = self.generate_model_card( card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
@@ -18,8 +18,7 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiOpenArmFollower(BimanualMixin, Robot): class BiOpenArmFollower(Robot):
""" """
Bimanual OpenArm Follower Arms Bimanual OpenArm Follower Arms
""" """
@@ -40,17 +39,15 @@ class BiOpenArmFollower(BimanualMixin, Robot):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their # Top-level cameras are distributed evenly: each arm's OpenArmFollower
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`). # will only open the cameras assigned to it. Per-arm cameras are used
self._top_level_cam_keys = set(config.cameras) # as fallback when top-level cameras are empty.
_collisions = self._top_level_cam_keys & set( if config.cameras:
config.left_arm_config.cameras left_cameras = config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras) right_cameras = {}
if _collisions: else:
raise ValueError( left_cameras = config.left_arm_config.cameras
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}" right_cameras = config.right_arm_config.cameras
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = OpenArmFollowerConfig( left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None, id=f"{config.id}_left" if config.id else None,
@@ -59,7 +56,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque, use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target, max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras, cameras=left_cameras,
side=config.left_arm_config.side, side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface, can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd, use_can_fd=config.left_arm_config.use_can_fd,
@@ -78,7 +75,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque, use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target, max_relative_target=config.right_arm_config.max_relative_target,
cameras=config.right_arm_config.cameras, cameras=right_cameras,
side=config.right_arm_config.side, side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface, can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd, use_can_fd=config.right_arm_config.use_can_fd,
@@ -98,19 +95,22 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@property @property
def _motors_ft(self) -> dict[str, type]: def _motors_ft(self) -> dict[str, type]:
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return { return {
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()}, **{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()}, **{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
} }
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {} # Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
for k, v in self.left_arm._cameras_ft.items(): # "right_wrist"), so we merge them directly — unlike motors which need the
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v # left_/right_ prefix to disambiguate identical per-arm joint names.
for k, v in self.right_arm._cameras_ft.items(): return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
out[f"right_{k}"] = v
return out
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -120,6 +120,27 @@ class BiOpenArmFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return self._motors_ft return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors." "Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -127,15 +148,21 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@check_if_not_connected @check_if_not_connected
def get_observation(self) -> RobotObservation: def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {} obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed. # Camera keys that should NOT get the arm prefix (they already have unique names)
for key, value in self.left_arm.get_observation().items(): left_cam_keys = set(self.left_arm.cameras.keys())
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value right_cam_keys = set(self.right_arm.cameras.keys())
# Add "right_" prefix # Right first, then left — matches the teleoperator (OpenArmMini) ordering
for key, value in self.right_arm.get_observation().items(): # and the dataset feature names recorded during data collection.
obs_dict[f"right_{key}"] = value right_obs = self.right_arm.get_observation()
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
return obs_dict return obs_dict
@@ -162,4 +189,9 @@ class BiOpenArmFollower(BimanualMixin, Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()} prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right} return {**prefixed_sent_action_right, **prefixed_sent_action_left}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -32,7 +32,5 @@ class BiOpenArmFollowerConfig(RobotConfig):
left_arm_config: OpenArmFollowerConfigBase left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras not attached to a specific side. Keys are kept as-is in # Top-level cameras shared across both arms.
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict) cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.decorators import check_if_not_connected
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
from ..robot import Robot from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiRebotB601Follower(BimanualMixin, Robot): class BiRebotB601Follower(Robot):
"""Bimanual Seeed Studio reBot B601-DM follower. """Bimanual Seeed Studio reBot B601-DM follower.
Composes two single-arm :class:`RebotB601Follower` instances. Observation and Composes two single-arm :class:`RebotB601Follower` instances. Observation and
@@ -42,18 +41,6 @@ class BiRebotB601Follower(BimanualMixin, Robot):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = RebotB601FollowerRobotConfig( left_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_left" if config.id else None, id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir, calibration_dir=config.calibration_dir,
@@ -62,7 +49,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
dm_serial_baud=config.left_arm_config.dm_serial_baud, dm_serial_baud=config.left_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target, max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras, cameras=config.left_arm_config.cameras,
motor_can_ids=config.left_arm_config.motor_can_ids, motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity, pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio, gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
@@ -99,12 +86,10 @@ class BiRebotB601Follower(BimanualMixin, Robot):
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {} return {
for k, v in self.left_arm._cameras_ft.items(): **{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v **{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
for k, v in self.right_arm._cameras_ft.items(): }
out[f"right_{k}"] = v
return out
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -114,13 +99,32 @@ class BiRebotB601Follower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return self._motors_ft return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected @check_if_not_connected
def get_observation(self) -> RobotObservation: def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {} obs_dict = {}
for k, v in self.left_arm.get_observation().items(): obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
for k, v in self.right_arm.get_observation().items():
obs_dict[f"right_{k}"] = v
return obs_dict return obs_dict
@check_if_not_connected @check_if_not_connected
@@ -139,3 +143,8 @@ class BiRebotB601Follower(BimanualMixin, Robot):
**{f"left_{k}": v for k, v in sent_action_left.items()}, **{f"left_{k}": v for k, v in sent_action_left.items()},
**{f"right_{k}": v for k, v in sent_action_right.items()}, **{f"right_{k}": v for k, v in sent_action_right.items()},
} }
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass
from lerobot.cameras import CameraConfig
from ..config import RobotConfig from ..config import RobotConfig
from ..rebot_b601_follower import RebotB601FollowerConfig from ..rebot_b601_follower import RebotB601FollowerConfig
@@ -29,8 +27,3 @@ class BiRebotB601FollowerConfig(RobotConfig):
left_arm_config: RebotB601FollowerConfig left_arm_config: RebotB601FollowerConfig
right_arm_config: RebotB601FollowerConfig right_arm_config: RebotB601FollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.decorators import check_if_not_connected
from ..robot import Robot from ..robot import Robot
from ..so_follower import SOFollower, SOFollowerRobotConfig from ..so_follower import SOFollower, SOFollowerRobotConfig
@@ -28,7 +27,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiSOFollower(BimanualMixin, Robot): class BiSOFollower(Robot):
""" """
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio [Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
""" """
@@ -40,18 +39,6 @@ class BiSOFollower(BimanualMixin, Robot):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = SOFollowerRobotConfig( left_arm_config = SOFollowerRobotConfig(
id=f"{config.id}_left" if config.id else None, id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir, calibration_dir=config.calibration_dir,
@@ -59,7 +46,7 @@ class BiSOFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target, max_relative_target=config.left_arm_config.max_relative_target,
use_degrees=config.left_arm_config.use_degrees, use_degrees=config.left_arm_config.use_degrees,
cameras=left_arm_cameras, cameras=config.left_arm_config.cameras,
) )
right_arm_config = SOFollowerRobotConfig( right_arm_config = SOFollowerRobotConfig(
@@ -90,12 +77,13 @@ class BiSOFollower(BimanualMixin, Robot):
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {} left_arm_cameras_ft = self.left_arm._cameras_ft
for k, v in self.left_arm._cameras_ft.items(): right_arm_cameras_ft = self.right_arm._cameras_ft
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items(): return {
out[f"right_{k}"] = v **{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
return out **{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
}
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -105,21 +93,42 @@ class BiSOFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return self._motors_ft return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
self.left_arm.setup_motors() self.left_arm.setup_motors()
self.right_arm.setup_motors() self.right_arm.setup_motors()
@check_if_not_connected @check_if_not_connected
def get_observation(self) -> RobotObservation: def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {} obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed. # Add "left_" prefix
for key, value in self.left_arm.get_observation().items(): left_obs = self.left_arm.get_observation()
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
# Add "right_" prefix # Add "right_" prefix
for key, value in self.right_arm.get_observation().items(): right_obs = self.right_arm.get_observation()
obs_dict[f"right_{key}"] = value obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
return obs_dict return obs_dict
@@ -142,3 +151,8 @@ class BiSOFollower(BimanualMixin, Robot):
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right} return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass
from lerobot.cameras import CameraConfig
from ..config import RobotConfig from ..config import RobotConfig
from ..so_follower import SOFollowerConfig from ..so_follower import SOFollowerConfig
@@ -29,8 +27,3 @@ class BiSOFollowerConfig(RobotConfig):
left_arm_config: SOFollowerConfig left_arm_config: SOFollowerConfig
right_arm_config: SOFollowerConfig right_arm_config: SOFollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
-1
View File
@@ -54,7 +54,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
homunculus, homunculus,
@@ -57,7 +57,6 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
gamepad, gamepad,
-1
View File
@@ -137,7 +137,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
homunculus, homunculus,
-1
View File
@@ -174,7 +174,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
homunculus, homunculus,
@@ -41,7 +41,6 @@ from lerobot.robots import ( # noqa: F401
) )
from lerobot.teleoperators import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
koch_leader, koch_leader,
@@ -89,7 +89,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
gamepad, gamepad,
+16 -4
View File
@@ -189,6 +189,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
require_package("accelerate", extra="training") require_package("accelerate", extra="training")
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
cfg.validate() cfg.validate()
@@ -197,8 +198,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation # We set find_unused_parameters=True to handle models with conditional computation
if accelerator is None: if accelerator is None:
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting. # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training). # Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
@@ -558,20 +557,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
train_tracker.reset_averages() train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step: if cfg.save_checkpoint and is_saving_step:
# All ranks must call get_state_dict; rank 0 gets the
# full state dict, others get an empty dict.
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
model_state_dict = accelerator.get_state_dict(policy)
if is_main_process: if is_main_process:
logging.info(f"Checkpoint policy after step {step}") logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
if is_fsdp:
# TODO(fsdp): sharded optimizer-state save/resume is not wired up yet.
logging.warning(
"FSDP checkpoint: saving model weights only (optimizer state skipped; "
"resume-from-checkpoint not supported under FSDP yet)."
)
save_checkpoint( save_checkpoint(
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
step=step, step=step,
cfg=cfg, cfg=cfg,
policy=accelerator.unwrap_model(policy), policy=accelerator.unwrap_model(policy),
optimizer=optimizer, optimizer=None if is_fsdp else optimizer,
scheduler=lr_scheduler, scheduler=lr_scheduler,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
num_processes=accelerator.num_processes, num_processes=accelerator.num_processes,
batch_size=cfg.batch_size, batch_size=cfg.batch_size,
model_state_dict=model_state_dict,
) )
update_last_checkpoint(checkpoint_dir) update_last_checkpoint(checkpoint_dir)
if wandb_logger: if wandb_logger:
@@ -634,6 +644,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if eval_env: if eval_env:
close_envs(eval_env) close_envs(eval_env)
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
model_state_dict = accelerator.get_state_dict(policy) if is_fsdp else None
if is_main_process: if is_main_process:
logging.info("End of training") logging.info("End of training")
@@ -643,7 +655,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if not cfg.is_reward_model_training and cfg.policy.use_peft: if not cfg.is_reward_model_training and cfg.policy.use_peft:
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model) unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
else: else:
unwrapped_model.push_model_to_hub(cfg) unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
preprocessor.push_to_hub(active_cfg.repo_id) preprocessor.push_to_hub(active_cfg.repo_id)
postprocessor.push_to_hub(active_cfg.repo_id) postprocessor.push_to_hub(active_cfg.repo_id)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
@@ -28,7 +27,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiOpenArmLeader(BimanualMixin, Teleoperator): class BiOpenArmLeader(Teleoperator):
""" """
Bimanual OpenArm Leader Arms Bimanual OpenArm Leader Arms
""" """
@@ -87,6 +86,27 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
return {} return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors." "Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -109,3 +129,8 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback # TODO: Implement force feedback
raise NotImplementedError raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_leader") @TeleoperatorConfig.register_subclass("bi_openarm_leader")
@dataclass @dataclass
class BiOpenArmLeaderConfig(TeleoperatorConfig): class BiOpenArmLeaderConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Leader teleoperators.""" """Configuration class for Bi OpenArm Follower robots."""
left_arm_config: OpenArmLeaderConfigBase left_arm_config: OpenArmLeaderConfigBase
right_arm_config: OpenArmLeaderConfigBase right_arm_config: OpenArmLeaderConfigBase
@@ -1,20 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_openarm_mini import BiOpenArmMini
from .config_bi_openarm_mini import BiOpenArmMiniConfig
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
@@ -1,101 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
from ..teleoperator import Teleoperator
from .config_bi_openarm_mini import BiOpenArmMiniConfig
logger = logging.getLogger(__name__)
class BiOpenArmMini(BimanualMixin, Teleoperator):
"""Bimanual OpenArm Mini teleoperator.
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
bimanual leader can teleoperate a bimanual OpenArm follower.
"""
config_class = BiOpenArmMiniConfig
name = "bi_openarm_mini"
def __init__(self, config: BiOpenArmMiniConfig):
super().__init__(config)
self.config = config
# `side` is forced to match left/right regardless of what the user passed
# on the per-arm base config — the bimanual wrapper owns the side semantics.
left_arm_config = OpenArmMiniConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
side="left",
use_degrees=config.left_arm_config.use_degrees,
)
right_arm_config = OpenArmMiniConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
side="right",
use_degrees=config.right_arm_config.use_degrees,
)
self.left_arm = OpenArmMini(left_arm_config)
self.right_arm = OpenArmMini(right_arm_config)
@cached_property
def action_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
}
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
action: RobotAction = {}
for k, v in self.left_arm.get_action().items():
action[f"left_{k}"] = v
for k, v in self.right_arm.get_action().items():
action[f"right_{k}"] = v
return action
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
if left_fb:
self.left_arm.send_feedback(left_fb)
if right_fb:
self.right_arm.send_feedback(right_fb)
@@ -1,29 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
from ..openarm_mini import OpenArmMiniConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
@dataclass
class BiOpenArmMiniConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Mini teleoperators."""
left_arm_config: OpenArmMiniConfigBase
right_arm_config: OpenArmMiniConfigBase
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .bi_rebot_102_leader import BiRebot102Leader from .bi_rebot_102_leader import BiRebotArm102Leader
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
__all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"] __all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]
@@ -18,17 +18,16 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.decorators import check_if_not_connected
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiRebot102Leader(BimanualMixin, Teleoperator): class BiRebotArm102Leader(Teleoperator):
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader. """Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
@@ -36,10 +35,10 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
leader can teleoperate a bimanual reBot B601 follower. leader can teleoperate a bimanual reBot B601 follower.
""" """
config_class = BiRebot102LeaderConfig config_class = BiRebotArm102LeaderConfig
name = "bi_rebot_102_leader" name = "bi_rebot_102_leader"
def __init__(self, config: BiRebot102LeaderConfig): def __init__(self, config: BiRebotArm102LeaderConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
@@ -77,6 +76,27 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
return {} return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected @check_if_not_connected
def get_action(self) -> RobotAction: def get_action(self) -> RobotAction:
action_dict = {} action_dict = {}
@@ -86,3 +106,8 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.") raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader") @TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
@dataclass @dataclass
class BiRebot102LeaderConfig(TeleoperatorConfig): class BiRebotArm102LeaderConfig(TeleoperatorConfig):
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator.""" """Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
left_arm_config: RebotArm102LeaderConfig left_arm_config: RebotArm102LeaderConfig
@@ -17,9 +17,7 @@
import logging import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..so_leader import SOLeader, SOLeaderTeleopConfig from ..so_leader import SOLeader, SOLeaderTeleopConfig
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
@@ -28,7 +26,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiSOLeader(BimanualMixin, Teleoperator): class BiSOLeader(Teleoperator):
""" """
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio [Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
""" """
@@ -69,12 +67,33 @@ class BiSOLeader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
return {} return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
self.left_arm.setup_motors() self.left_arm.setup_motors()
self.right_arm.setup_motors() self.right_arm.setup_motors()
@check_if_not_connected @check_if_not_connected
def get_action(self) -> RobotAction: def get_action(self) -> dict[str, float]:
action_dict = {} action_dict = {}
# Add "left_" prefix # Add "left_" prefix
@@ -90,3 +109,8 @@ class BiSOLeader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback # TODO: Implement force feedback
raise NotImplementedError raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase from .config_openarm_mini import OpenArmMiniConfig
from .openarm_mini import OpenArmMini from .openarm_mini import OpenArmMini
__all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"] __all__ = ["OpenArmMini", "OpenArmMiniConfig"]
@@ -19,21 +19,12 @@ from dataclasses import dataclass
from ..config import TeleoperatorConfig from ..config import TeleoperatorConfig
@dataclass
class OpenArmMiniConfigBase:
"""Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
# Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
port: str
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
# during readout. If `None`, no flipping is applied.
side: str | None = None
use_degrees: bool = True
@TeleoperatorConfig.register_subclass("openarm_mini") @TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass @dataclass
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase): class OpenArmMiniConfig(TeleoperatorConfig):
pass """Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
port_right: str = "/dev/ttyUSB0"
port_left: str = "/dev/ttyUSB1"
use_degrees: bool = True
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Per-side motor direction flips applied during readout. # Motors whose direction is inverted during readout
SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = { RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
"left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"], LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
}
# Leader joint 6 follower joint 7 (symmetric — its own inverse). # Leader joint 6 maps to follower joint 7 and vice versa
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"} JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
GRIPPER_TELEOP_TO_DEGREES = -0.65 GRIPPER_TELEOP_TO_DEGREES = -0.65
class OpenArmMini(Teleoperator): class OpenArmMini(Teleoperator):
"""OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper). """
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these. Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
""" """
config_class = OpenArmMiniConfig config_class = OpenArmMiniConfig
@@ -56,12 +56,9 @@ class OpenArmMini(Teleoperator):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
norm_mode_body = MotorNormMode.DEGREES norm_mode_body = MotorNormMode.DEGREES
motors = {
motors_right = {
"joint_1": Motor(1, "sts3215", norm_mode_body), "joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body), "joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body), "joint_3": Motor(3, "sts3215", norm_mode_body),
@@ -72,15 +69,46 @@ class OpenArmMini(Teleoperator):
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100), "gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
} }
self.bus = FeetechMotorsBus( motors_left = {
port=self.config.port, "joint_1": Motor(1, "sts3215", norm_mode_body),
motors=motors, "joint_2": Motor(2, "sts3215", norm_mode_body),
calibration=self.calibration, "joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
) )
@property @property
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors} # Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property @property
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
@@ -88,12 +116,14 @@ class OpenArmMini(Teleoperator):
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self.bus.is_connected return self.bus_right.is_connected and self.bus_left.is_connected
@check_if_already_connected @check_if_already_connected
def connect(self, calibrate: bool = True) -> None: def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting arm on {self.config.port}...") logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus.connect() self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
if calibrate: if calibrate:
self.calibrate() self.calibrate()
@@ -103,14 +133,14 @@ class OpenArmMini(Teleoperator):
@property @property
def is_calibrated(self) -> bool: def is_calibrated(self) -> bool:
return self.bus.is_calibrated return self.bus_right.is_calibrated and self.bus_left.is_calibrated
def calibrate(self) -> None: def calibrate(self) -> None:
""" """
Run calibration procedure for a single OpenArm Mini arm. Run calibration procedure for OpenArm Mini.
1. Disable torque 1. Disable torque
2. Ask user to position arm in hanging position with gripper closed 2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position via half-turn homing 3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions) 4. Interactive gripper calibration (open/close positions)
5. Save calibration 5. Save calibration
@@ -122,51 +152,70 @@ class OpenArmMini(Teleoperator):
) )
if user_input.strip().lower() != "c": if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}") logger.info(f"Using existing calibration for {self.id}")
self.bus.write_calibration(self.calibration) cal_right = {
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return return
logger.info(f"\nRunning calibration for {self}") logger.info(f"\nRunning calibration for {self}")
self.bus.disable_torque() self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
logger.info("Setting Phase to 12 for all motors...") self._save_calibration()
for motor in self.bus.motors: print(f"\nCalibration complete and saved to {self.calibration_fpath}")
self.bus.write("Phase", motor, 12)
for motor in self.bus.motors: def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) """Calibrate a single arm with Feetech motors."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input( input(
"\nCalibration: Zero Position\n" f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"Position the arm in the following configuration:\n" "Position the arm in the following configuration:\n"
" - Arm hanging straight down\n" " - Arm hanging straight down\n"
" - Gripper closed\n" " - Gripper closed\n"
"Press ENTER when ready..." "Press ENTER when ready..."
) )
homing_offsets = self.bus.set_half_turn_homings() homing_offsets = bus.set_half_turn_homings()
logger.info("Arm zero position set.") logger.info(f"{arm_name.capitalize()} arm zero position set.")
print("\nSetting motor ranges\n") print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
if self.calibration is None: if self.calibration is None:
self.calibration = {} self.calibration = {}
motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model] motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
max_res = motor_resolution - 1 max_res = motor_resolution - 1
for motor_name, motor in self.bus.motors.items(): for motor_name, motor in bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
if motor_name == "gripper": if motor_name == "gripper":
input( input(
"\nGripper Calibration\n" f"\nGripper Calibration ({arm_name.upper()} arm)\n"
"Step 1: CLOSE the gripper fully\n" f"Step 1: CLOSE the gripper fully\n"
"Press ENTER when gripper is closed..." f"Press ENTER when gripper is closed..."
) )
closed_pos = self.bus.read("Present_Position", motor_name, normalize=False) closed_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}") logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...") input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = self.bus.read("Present_Position", motor_name, normalize=False) open_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}") logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos: if closed_pos < open_pos:
@@ -179,16 +228,16 @@ class OpenArmMini(Teleoperator):
drive_mode = 1 drive_mode = 1
logger.info( logger.info(
f" {motor_name}: range set to [{range_min}, {range_max}] " f" {prefixed_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})" f"(0=closed, 100=open, drive_mode={drive_mode})"
) )
else: else:
range_min = 0 range_min = 0
range_max = max_res range_max = max_res
drive_mode = 0 drive_mode = 0
logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)") logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[motor_name] = MotorCalibration( self.calibration[prefixed_name] = MotorCalibration(
id=motor.id, id=motor.id,
drive_mode=drive_mode, drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name], homing_offset=homing_offsets[motor_name],
@@ -196,68 +245,108 @@ class OpenArmMini(Teleoperator):
range_max=range_max, range_max=range_max,
) )
self.bus.write_calibration(self.calibration) cal_for_bus = {
self._save_calibration() k.replace(f"{arm_name}_", ""): v
print(f"\nCalibration complete and saved to {self.calibration_fpath}") for k, v in self.calibration.items()
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
def configure(self) -> None: def configure(self) -> None:
self.bus.disable_torque() self.bus_right.disable_torque()
self.bus.configure_motors() self.bus_right.configure_motors()
for motor in self.bus.motors: for motor in self.bus_right.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None: def setup_motors(self) -> None:
for motor in reversed(self.bus.motors): print("\nSetting up RIGHT arm motors...")
input(f"Connect the controller board to the '{motor}' motor only and press enter.") for motor in reversed(self.bus_right.motors):
self.bus.setup_motor(motor) input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") self.bus_right.setup_motor(motor)
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
@check_if_not_connected @check_if_not_connected
def get_action(self) -> RobotAction: def get_action(self) -> RobotAction:
"""Get current action (read positions from all motors).""" """Get current action from both arms (read positions from all motors)."""
start = time.perf_counter() start = time.perf_counter()
positions = self.bus.sync_read("Present_Position") right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa. # Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
# Per-side direction flip is applied based on the configured `side`.
action: dict[str, Any] = {} action: dict[str, Any] = {}
for motor, val in positions.items(): for motor, val in right_positions.items():
target = JOINT_REMAP.get(motor, motor) target = JOINT_REMAP.get(motor, motor)
if motor == "gripper": if motor == "gripper":
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65° # Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else: else:
action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3 dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms") logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action return action
def enable_torque(self) -> None: def enable_torque(self) -> None:
self.bus.enable_torque() """Enable torque on both arms for position control."""
self.bus_right.enable_torque()
self.bus_left.enable_torque()
def disable_torque(self) -> None: def disable_torque(self) -> None:
self.bus.disable_torque() """Disable torque on both arms for free movement."""
self.bus_right.disable_torque()
self.bus_left.disable_torque()
def write_goal_positions(self, positions: dict[str, float]) -> None: def write_goal_positions(self, positions: dict[str, float]) -> None:
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic).""" """Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
goals: dict[str, float] = {} right_goals: dict[str, float] = {}
left_goals: dict[str, float] = {}
for key, val in positions.items(): for key, val in positions.items():
if not key.endswith(".pos"): if not key.endswith(".pos"):
continue continue
base = key.removesuffix(".pos") motor_name = key.removesuffix(".pos")
# JOINT_REMAP is symmetric (its own inverse). if motor_name.startswith("right_"):
target = JOINT_REMAP.get(base, base) base = motor_name.removeprefix("right_")
if base == "gripper": # Reverse remap: follower joint_7 → leader joint_6 and vice versa
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100 target = JOINT_REMAP_REVERSE.get(base, base)
goals[target] = val / GRIPPER_TELEOP_TO_DEGREES if base == "gripper":
else: # Convert robot degrees to teleop 0-100: 0°→0, -65°→100
# Un-flip using the ORIGINAL motor name (target = leader motor) right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
goals[target] = -val if target in self._motors_to_flip else val else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
elif motor_name.startswith("left_"):
base = motor_name.removeprefix("left_")
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
if goals: if right_goals:
self.bus.sync_write("Goal_Position", goals) self.bus_right.sync_write("Goal_Position", right_goals)
if left_goals:
self.bus_left.sync_write("Goal_Position", left_goals)
@check_if_not_connected @check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
@@ -265,5 +354,6 @@ class OpenArmMini(Teleoperator):
@check_if_not_connected @check_if_not_connected
def disconnect(self) -> None: def disconnect(self) -> None:
self.bus.disconnect() self.bus_right.disconnect()
self.bus_left.disconnect()
logger.info(f"{self} disconnected.") logger.info(f"{self} disconnected.")
+2 -6
View File
@@ -99,18 +99,14 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .openarm_mini import OpenArmMini from .openarm_mini import OpenArmMini
return OpenArmMini(config) return OpenArmMini(config)
elif config.type == "bi_openarm_mini":
from .bi_openarm_mini import BiOpenArmMini
return BiOpenArmMini(config)
elif config.type == "rebot_102_leader": elif config.type == "rebot_102_leader":
from .rebot_102_leader import RebotArm102Leader from .rebot_102_leader import RebotArm102Leader
return RebotArm102Leader(config) return RebotArm102Leader(config)
elif config.type == "bi_rebot_102_leader": elif config.type == "bi_rebot_102_leader":
from .bi_rebot_102_leader import BiRebot102Leader from .bi_rebot_102_leader import BiRebotArm102Leader
return BiRebot102Leader(config) return BiRebotArm102Leader(config)
else: else:
try: try:
return cast("Teleoperator", make_device_from_device_class(config)) return cast("Teleoperator", make_device_from_device_class(config))
-63
View File
@@ -1,63 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
class BimanualMixin:
"""Lifecycle delegation for bimanual robots and teleoperators.
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
their own ``__init__``. They retain ownership of feature dicts and the
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
``send_feedback``), which vary per-embodiment.
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
take precedence in the MRO::
class BiFooFollower(BimanualMixin, Robot): ...
"""
left_arm: Any
right_arm: Any
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
-73
View File
@@ -28,7 +28,6 @@ import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])") pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import pyarrow.parquet as pq # noqa: E402 import pyarrow.parquet as pq # noqa: E402
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402 from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
@@ -345,78 +344,6 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
assert len(dataset) == 24 assert len(dataset) == 24
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
from lerobot.datasets.io_utils import write_tasks
from lerobot.utils.io_utils import write_json
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
for ep, length in enumerate(episode_lengths):
episode_index += [ep] * length
frame_index += list(range(length))
timestamp += [round(i / fps, 6) for i in range(length)]
task_index += [0] * length
subtask_index += [0] * length # legacy column the writer must drop
pd.DataFrame(
{
"episode_index": episode_index,
"frame_index": frame_index,
"timestamp": timestamp,
"task_index": task_index,
"subtask_index": subtask_index,
}
).to_parquet(data_dir / "file-000.parquet", index=False)
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
write_tasks(tasks_df, root)
write_json(
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
root / "meta" / "info.json",
)
return root
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
"""Rewriting a packed shard must keep one row group per episode, not collapse
every episode into a single giant row group."""
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
shard = root / "data" / "chunk-000" / "file-000.parquet"
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
staging_dir = tmp_path / "stage"
for ep in range(len(episode_lengths)):
_stage_episode(
staging_dir,
ep,
plan=[
{
"role": "assistant",
"content": f"subtask for ep {ep}",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
}
],
)
records = list(iter_episodes(root))
LanguageColumnsWriter().write_all(records, staging_dir, root)
# One row group per episode, with row counts matching the episode lengths.
md = pq.ParquetFile(shard).metadata
assert md.num_row_groups == len(episode_lengths)
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
# Language columns are still present after the per-episode rewrite.
table = pq.read_table(shard)
assert "language_persistent" in table.column_names
assert "language_events" in table.column_names
def test_speech_atom_shape_matches_plan_spec() -> None: def test_speech_atom_shape_matches_plan_spec() -> None:
atom = speech_atom(2.5, "I'm cleaning up!") atom = speech_atom(2.5, "I'm cleaning up!")
assert atom["role"] == "assistant" assert atom["role"] == "assistant"
-55
View File
@@ -32,26 +32,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_REPO_ID
def assert_data_shards_one_row_group_per_episode(root):
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
import pyarrow.parquet as pq
shards = sorted((root / "data").rglob("*.parquet"))
assert shards, f"no data shards found under {root}/data"
n_episodes = 0
for shard in shards:
pf = pq.ParquetFile(shard)
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
assert pf.metadata.num_row_groups == len(set(episodes)), shard
for i in range(pf.metadata.num_row_groups):
rg_episodes = set(
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
)
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
n_episodes += len(set(episodes))
return n_episodes
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames): def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
"""Test that total number of episodes and frames are correctly aggregated.""" """Test that total number of episodes and frames are correctly aggregated."""
assert aggr_ds.num_episodes == expected_episodes, ( assert aggr_ds.num_episodes == expected_episodes, (
@@ -586,41 +566,6 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
) )
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
Covers both the non-image (``df.to_parquet``) and image
(``to_parquet_with_hf_images``) write branches, including the merge-into-
existing-file branch via a low file-size threshold that forces packing.
"""
ds_0 = lerobot_dataset_factory(
root=tmp_path / "rg_0",
repo_id=f"{DUMMY_REPO_ID}_rg_0",
total_episodes=3,
total_frames=60,
use_videos=use_videos,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "rg_1",
repo_id=f"{DUMMY_REPO_ID}_rg_1",
total_episodes=4,
total_frames=80,
use_videos=use_videos,
)
aggr_root = tmp_path / "rg_aggr"
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
aggr_root=aggr_root,
)
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
"""Test aggregation of image-based datasets preserves HuggingFace Image schema. """Test aggregation of image-based datasets preserves HuggingFace Image schema.
+1 -16
View File
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
from lerobot.transforms import ImageTransforms, ImageTransformsConfig from lerobot.transforms import ImageTransforms, ImageTransformsConfig
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
from lerobot.utils.feature_utils import hw_to_dataset_features from lerobot.utils.feature_utils import hw_to_dataset_features
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig from tests.mocks.mock_robot import MockRobotConfig
from tests.utils import require_x86_64_kernel from tests.utils import require_x86_64_kernel
@@ -133,21 +133,6 @@ def test_dataset_feature_with_forward_slash_raises_error():
) )
def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
# ``create`` must deep-copy features so a dataset built from another's features stays independent.
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
)
dataset_copy = empty_lerobot_dataset_factory(
root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False
)
original_shape = dataset.meta.info.features["state"]["shape"]
dataset_copy.meta.info.features["state"]["shape"] = (999,)
assert dataset.meta.info.features["state"]["shape"] == original_shape
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+24
View File
@@ -23,6 +23,7 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from packaging import version from packaging import version
from safetensors.torch import load_file from safetensors.torch import load_file
@@ -300,6 +301,29 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
"""Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
policy_cls = get_policy_class("act")
policy_cfg = make_policy_config("act")
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
policy = policy_cls(policy_cfg)
policy.to(policy_cfg.device)
save_dir = tmp_path / "fsdp_state_dict"
policy.save_pretrained(save_dir, state_dict=policy.state_dict())
# A single, unsharded safetensors file (no sharded set + index).
assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
@pytest.mark.parametrize("multikey", [True, False]) @pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool): def test_multikey_construction(multikey: bool):
""" """
+3 -21
View File
@@ -2370,32 +2370,14 @@ def test_aggregate_images_when_use_videos_false():
out = aggregate_pipeline_dataset_features( out = aggregate_pipeline_dataset_features(
pipeline=rp, pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=False, # images kept, stored as "image" dtype use_videos=False, # expect "image" dtype
patterns=None, patterns=None,
) )
key = f"{OBS_IMAGES}.back" key = f"{OBS_IMAGES}.back"
key_front = f"{OBS_IMAGES}.front" key_front = f"{OBS_IMAGES}.front"
assert key in out assert key not in out
assert key_front in out assert key_front not in out
assert out[key]["dtype"] == "image"
assert out[key_front]["dtype"] == "image"
assert out[key]["shape"] == initial["back"]
def test_aggregate_images_excluded():
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
initial = {"back": (480, 640, 3)}
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
exclude_images=True,
patterns=None,
)
assert f"{OBS_IMAGES}.back" not in out
assert f"{OBS_IMAGES}.front" not in out
def test_aggregate_images_when_use_videos_true(): def test_aggregate_images_when_use_videos_true():
+3 -3
View File
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
from lerobot.teleoperators.rebot_102_leader import ( from lerobot.teleoperators.rebot_102_leader import (
RebotArm102Leader, RebotArm102Leader,
RebotArm102LeaderConfig, RebotArm102LeaderConfig,
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
def test_bimanual_prefixes_features(): def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None): with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebot102LeaderConfig( cfg = BiRebotArm102LeaderConfig(
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"), left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"), right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
) )
teleop = BiRebot102Leader(cfg) teleop = BiRebotArm102Leader(cfg)
assert any(k.startswith("left_") for k in teleop.action_features) assert any(k.startswith("left_") for k in teleop.action_features)
assert any(k.startswith("right_") for k in teleop.action_features) assert any(k.startswith("right_") for k in teleop.action_features)
assert "left_gripper.pos" in teleop.action_features assert "left_gripper.pos" in teleop.action_features
Generated
+900 -949
View File
File diff suppressed because it is too large Load Diff