mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2e0deff3ab | |||
| b42d124007 | |||
| 3ce50c3468 | |||
| 44fd3c0a0e | |||
| 0483afc743 |
@@ -136,7 +136,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
|||||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||||
- **[LeLab](https://github.com/huggingface/leLab):** A web interface for LeRobot — teleoperate, calibrate, record datasets, replay, and train your SO arm from the browser, no CLI required.
|
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
|||||||
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
|
|||||||
|
|
||||||
**Compatible teleoperators:**
|
**Compatible teleoperators:**
|
||||||
|
|
||||||
- `bi_openarm_mini` - Bimanual OpenArm Mini
|
- `openarm_mini` - OpenArm Mini
|
||||||
- `so_leader` - SO100 / SO101 leader arm
|
- `so_leader` - SO100 / SO101 leader arm
|
||||||
|
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
|
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
|
|||||||
--robot.right_arm_config.port=can0 \
|
--robot.right_arm_config.port=can0 \
|
||||||
--robot.right_arm_config.side=right \
|
--robot.right_arm_config.side=right \
|
||||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||||
--teleop.type=bi_openarm_mini \
|
--teleop.type=openarm_mini \
|
||||||
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
--teleop.port_left=/dev/ttyACM0 \
|
||||||
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
--teleop.port_right=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/rollout_hil_dataset \
|
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
|
|||||||
--robot.right_arm_config.port=can0 \
|
--robot.right_arm_config.port=can0 \
|
||||||
--robot.right_arm_config.side=right \
|
--robot.right_arm_config.side=right \
|
||||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||||
--teleop.type=bi_openarm_mini \
|
--teleop.type=openarm_mini \
|
||||||
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
--teleop.port_left=/dev/ttyACM0 \
|
||||||
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
--teleop.port_right=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ lerobot-rollout \
|
|||||||
--strategy.num_episodes=20 \
|
--strategy.num_episodes=20 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--robot.type=bi_openarm_follower \
|
--robot.type=bi_openarm_follower \
|
||||||
--teleop.type=bi_openarm_mini \
|
--teleop.type=openarm_mini \
|
||||||
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||||
--dataset.single_task="Fold the T-shirt"
|
--dataset.single_task="Fold the T-shirt"
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -113,6 +113,52 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
|
|||||||
--policy=act
|
--policy=act
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Training Large Models with FSDP
|
||||||
|
|
||||||
|
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
|
||||||
|
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
|
||||||
|
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
|
||||||
|
|
||||||
|
An example on how to launch LeRobot training with FSDP across 4 GPUs (1 machine):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||||
|
--policy.type=<your_policy> \
|
||||||
|
--output_dir=outputs/train/my_policy_fsdp
|
||||||
|
```
|
||||||
|
|
||||||
|
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
distributed_type: FSDP
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 4
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 1
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
|
||||||
|
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
```
|
||||||
|
|
||||||
|
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
|
||||||
|
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
|
||||||
|
optimizer before `accelerator.prepare()`.
|
||||||
|
|
||||||
|
### FSDP checkpoints
|
||||||
|
|
||||||
|
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
|
||||||
|
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two thigs to look out for:
|
||||||
|
|
||||||
|
- With mixed precision, (`bf16`/`fp16`) FSDP keeps an fp32 master copy, so the checkpoint is fp32
|
||||||
|
(~2× the bf16 size on disk) and is cast back to the policy dtype on load.
|
||||||
|
- **Optimizer state is not saved under FSDP**, so **resume-from-checkpoint is not supported**.
|
||||||
|
Saved weights are fully usable for evaluation and fine-tuning.
|
||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
||||||
|
|||||||
@@ -54,7 +54,6 @@ from typing import Any
|
|||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
|
|
||||||
from lerobot.datasets.language import (
|
from lerobot.datasets.language import (
|
||||||
EVENT_ONLY_STYLES,
|
EVENT_ONLY_STYLES,
|
||||||
LANGUAGE_EVENTS,
|
LANGUAGE_EVENTS,
|
||||||
@@ -275,11 +274,12 @@ class LanguageColumnsWriter:
|
|||||||
new_table = self._materialize_table(
|
new_table = self._materialize_table(
|
||||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||||
)
|
)
|
||||||
# Re-emit one row group per episode (a bulk pq.write_table would collapse
|
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||||
# them into one). Write to a sibling tmp path and atomically rename so a
|
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||||
# crash mid-write can't leave a half-written shard.
|
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||||
|
# Windows when source and target sit on the same filesystem.
|
||||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||||
write_table_one_row_group_per_episode(new_table, tmp_path)
|
pq.write_table(new_table, tmp_path)
|
||||||
tmp_path.replace(path)
|
tmp_path.replace(path)
|
||||||
|
|
||||||
def _materialize_table(
|
def _materialize_table(
|
||||||
|
|||||||
@@ -442,12 +442,11 @@ class OpenCVCamera(Camera):
|
|||||||
|
|
||||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||||
"""
|
"""
|
||||||
stop_event = self.stop_event
|
if self.stop_event is None:
|
||||||
if stop_event is None:
|
|
||||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||||
|
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
while not stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
raw_frame = self._read_from_hardware()
|
raw_frame = self._read_from_hardware()
|
||||||
processed_frame = self._postprocess_image(raw_frame)
|
processed_frame = self._postprocess_image(raw_frame)
|
||||||
|
|||||||
@@ -471,12 +471,11 @@ class RealSenseCamera(Camera):
|
|||||||
|
|
||||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||||
"""
|
"""
|
||||||
stop_event = self.stop_event
|
if self.stop_event is None:
|
||||||
if stop_event is None:
|
|
||||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||||
|
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
while not stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
frame = self._read_from_hardware()
|
frame = self._read_from_hardware()
|
||||||
color_frame_raw = frame.get_color_frame()
|
color_frame_raw = frame.get_color_frame()
|
||||||
|
|||||||
@@ -246,12 +246,11 @@ class ZMQCamera(Camera):
|
|||||||
"""
|
"""
|
||||||
Internal loop run by the background thread for asynchronous reading.
|
Internal loop run by the background thread for asynchronous reading.
|
||||||
"""
|
"""
|
||||||
stop_event = self.stop_event
|
if self.stop_event is None:
|
||||||
if stop_event is None:
|
|
||||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||||
|
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
while not stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
frame = self._read_from_hardware()
|
frame = self._read_from_hardware()
|
||||||
capture_time = time.perf_counter()
|
capture_time = time.perf_counter()
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ def save_checkpoint(
|
|||||||
postprocessor: PolicyProcessorPipeline | None = None,
|
postprocessor: PolicyProcessorPipeline | None = None,
|
||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
|
model_state_dict: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""This function creates the following directory structure:
|
"""This function creates the following directory structure:
|
||||||
|
|
||||||
@@ -127,9 +128,14 @@ def save_checkpoint(
|
|||||||
resume. Defaults to None (not recorded).
|
resume. Defaults to None (not recorded).
|
||||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||||
resume. Defaults to None (not recorded).
|
resume. Defaults to None (not recorded).
|
||||||
|
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
|
||||||
|
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
|
||||||
|
cross-rank collective and passes it here so rank 0 can write it directly. It holds
|
||||||
|
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
|
||||||
|
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
|
||||||
"""
|
"""
|
||||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||||
policy.save_pretrained(pretrained_dir)
|
policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
|
||||||
cfg.save_pretrained(pretrained_dir)
|
cfg.save_pretrained(pretrained_dir)
|
||||||
if cfg.peft is not None:
|
if cfg.peft is not None:
|
||||||
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
|
|||||||
from .io_utils import (
|
from .io_utils import (
|
||||||
get_file_size_in_mb,
|
get_file_size_in_mb,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
to_parquet_one_row_group_per_episode,
|
|
||||||
to_parquet_with_hf_images,
|
to_parquet_with_hf_images,
|
||||||
write_info,
|
write_info,
|
||||||
write_stats,
|
write_stats,
|
||||||
@@ -552,7 +551,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
|||||||
aggr_root=dst_meta.root,
|
aggr_root=dst_meta.root,
|
||||||
hf_features=hf_features,
|
hf_features=hf_features,
|
||||||
concatenate=concatenate_data,
|
concatenate=concatenate_data,
|
||||||
one_row_group_per_episode=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Record the mapping from source to actual destination
|
# Record the mapping from source to actual destination
|
||||||
@@ -630,7 +628,6 @@ def append_or_create_parquet_file(
|
|||||||
aggr_root: Path = None,
|
aggr_root: Path = None,
|
||||||
hf_features: datasets.Features | None = None,
|
hf_features: datasets.Features | None = None,
|
||||||
concatenate: bool = True,
|
concatenate: bool = True,
|
||||||
one_row_group_per_episode: bool = False,
|
|
||||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||||
|
|
||||||
@@ -648,8 +645,6 @@ def append_or_create_parquet_file(
|
|||||||
aggr_root: Root path for the aggregated dataset.
|
aggr_root: Root path for the aggregated dataset.
|
||||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||||
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
||||||
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
|
|
||||||
the episodes-metadata parquet (already one row per episode).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||||
@@ -662,8 +657,6 @@ def append_or_create_parquet_file(
|
|||||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||||
elif one_row_group_per_episode:
|
|
||||||
to_parquet_one_row_group_per_episode(df, dst_path)
|
|
||||||
else:
|
else:
|
||||||
df.to_parquet(dst_path)
|
df.to_parquet(dst_path)
|
||||||
return idx, (dst_chunk, dst_file)
|
return idx, (dst_chunk, dst_file)
|
||||||
@@ -690,8 +683,6 @@ def append_or_create_parquet_file(
|
|||||||
|
|
||||||
if contains_images:
|
if contains_images:
|
||||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||||
elif one_row_group_per_episode:
|
|
||||||
to_parquet_one_row_group_per_episode(final_df, target_path)
|
|
||||||
else:
|
else:
|
||||||
final_df.to_parquet(target_path)
|
final_df.to_parquet(target_path)
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
import contextlib
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -710,7 +709,7 @@ class LeRobotDatasetMetadata:
|
|||||||
|
|
||||||
obj.root.mkdir(parents=True, exist_ok=False)
|
obj.root.mkdir(parents=True, exist_ok=False)
|
||||||
|
|
||||||
features = {**deepcopy(features), **DEFAULT_FEATURES}
|
features = {**features, **DEFAULT_FEATURES}
|
||||||
_validate_feature_names(features)
|
_validate_feature_names(features)
|
||||||
|
|
||||||
obj.tasks = None
|
obj.tasks = None
|
||||||
|
|||||||
@@ -74,8 +74,6 @@ class DatasetReader:
|
|||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self._tolerance_s = tolerance_s
|
self._tolerance_s = tolerance_s
|
||||||
self._video_backend = video_backend
|
self._video_backend = video_backend
|
||||||
if image_transforms is not None and not callable(image_transforms):
|
|
||||||
raise TypeError("image_transforms must be callable or None.")
|
|
||||||
self._image_transforms = image_transforms
|
self._image_transforms = image_transforms
|
||||||
self._return_uint8 = return_uint8
|
self._return_uint8 = return_uint8
|
||||||
|
|
||||||
@@ -88,16 +86,6 @@ class DatasetReader:
|
|||||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||||
|
|
||||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
|
||||||
"""Replace the transform applied to visual observations."""
|
|
||||||
if image_transforms is not None and not callable(image_transforms):
|
|
||||||
raise TypeError("image_transforms must be callable or None.")
|
|
||||||
self._image_transforms = image_transforms
|
|
||||||
|
|
||||||
def clear_image_transforms(self) -> None:
|
|
||||||
"""Remove the transform applied to visual observations."""
|
|
||||||
self._image_transforms = None
|
|
||||||
|
|
||||||
def try_load(self) -> bool:
|
def try_load(self) -> bool:
|
||||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ import logging
|
|||||||
import shutil
|
import shutil
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -1102,9 +1101,7 @@ def _copy_episodes_metadata_and_stats(
|
|||||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||||
for key in dst_meta.video_keys:
|
for key in dst_meta.video_keys:
|
||||||
if key in src_dataset.meta.features:
|
if key in src_dataset.meta.features:
|
||||||
dst_meta.info.features[key]["info"] = deepcopy(
|
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||||
src_dataset.meta.info.features[key].get("info", {})
|
|
||||||
)
|
|
||||||
|
|
||||||
write_info(dst_meta.info, dst_meta.root)
|
write_info(dst_meta.info, dst_meta.root)
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import datasets
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas
|
import pandas
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pyarrow as pa
|
|
||||||
import pyarrow.dataset as pa_ds
|
import pyarrow.dataset as pa_ds
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
import torch
|
import torch
|
||||||
@@ -271,49 +270,21 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
|||||||
return items_dict
|
return items_dict
|
||||||
|
|
||||||
|
|
||||||
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
|
|
||||||
"""Write ``table`` with one parquet row group per episode (in episode order).
|
|
||||||
|
|
||||||
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
|
|
||||||
mirroring the recording writer. ``table`` must carry a contiguous
|
|
||||||
``episode_index`` column.
|
|
||||||
"""
|
|
||||||
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
|
|
||||||
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
|
|
||||||
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
|
|
||||||
try:
|
|
||||||
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
|
|
||||||
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
|
|
||||||
finally:
|
|
||||||
writer.close()
|
|
||||||
|
|
||||||
|
|
||||||
def to_parquet_with_hf_images(
|
def to_parquet_with_hf_images(
|
||||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
|
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||||
|
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||||
|
|
||||||
Images are embedded into the arrow table first (``ParquetWriter.write_table``
|
Args:
|
||||||
does not embed external image files like ``Dataset.to_parquet`` does).
|
df: DataFrame to write to parquet.
|
||||||
``features`` types image columns as ``Image()`` in the parquet schema.
|
path: Path to write the parquet file.
|
||||||
|
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||||
|
are properly typed as Image() in the parquet schema.
|
||||||
"""
|
"""
|
||||||
|
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||||
ds = embed_images(ds)
|
ds.to_parquet(path)
|
||||||
table = ds.with_format("arrow")[:]
|
|
||||||
if "episode_index" in table.column_names:
|
|
||||||
write_table_one_row_group_per_episode(table, path)
|
|
||||||
else:
|
|
||||||
# No episode boundaries to align row groups to — keep a single write.
|
|
||||||
pq.write_table(table, str(path))
|
|
||||||
|
|
||||||
|
|
||||||
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
|
|
||||||
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
|
|
||||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
|
||||||
if "episode_index" in table.column_names:
|
|
||||||
write_table_one_row_group_per_episode(table, path)
|
|
||||||
else:
|
|
||||||
pq.write_table(table, str(path))
|
|
||||||
|
|
||||||
|
|
||||||
def item_to_torch(item: dict) -> dict:
|
def item_to_torch(item: dict) -> dict:
|
||||||
|
|||||||
@@ -201,6 +201,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self._requested_root = Path(root) if root else None
|
self._requested_root = Path(root) if root else None
|
||||||
|
self.reader = None
|
||||||
|
self.set_image_transforms(image_transforms)
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
@@ -247,7 +249,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
return_uint8=self._return_uint8,
|
return_uint8=self._return_uint8,
|
||||||
)
|
)
|
||||||
self.image_transforms = image_transforms
|
|
||||||
|
|
||||||
# Load actual data
|
# Load actual data
|
||||||
if force_cache_sync or not self.reader.try_load():
|
if force_cache_sync or not self.reader.try_load():
|
||||||
@@ -504,14 +505,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||||
"""Replace the transform applied to visual observations."""
|
"""Replace the transform applied to visual observations."""
|
||||||
self._ensure_reader().set_image_transforms(image_transforms)
|
if image_transforms is not None and not callable(image_transforms):
|
||||||
|
raise TypeError("image_transforms must be callable or None.")
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
|
if self.reader is not None:
|
||||||
|
self.reader._image_transforms = image_transforms
|
||||||
|
|
||||||
def clear_image_transforms(self) -> None:
|
def clear_image_transforms(self) -> None:
|
||||||
"""Remove the transform applied to visual observations."""
|
"""Remove the transform applied to visual observations."""
|
||||||
if self.reader is not None:
|
self.set_image_transforms(None)
|
||||||
self.reader.set_image_transforms(None)
|
|
||||||
self.image_transforms = None
|
|
||||||
|
|
||||||
# ── Hub methods (stay on facade) ──────────────────────────────────
|
# ── Hub methods (stay on facade) ──────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -70,21 +70,19 @@ def aggregate_pipeline_dataset_features(
|
|||||||
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
||||||
*,
|
*,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
exclude_images: bool = False,
|
|
||||||
patterns: Sequence[str] | None = None,
|
patterns: Sequence[str] | None = None,
|
||||||
) -> dict[str, dict]:
|
) -> dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
||||||
|
|
||||||
This function transforms initial features using the pipeline, categorizes them as action or observations
|
This function transforms initial features using the pipeline, categorizes them as action or observations
|
||||||
(image or state), filters them based on `exclude_images` and `patterns`, and finally
|
(image or state), filters them based on `use_videos` and `patterns`, and finally
|
||||||
formats them for use with a Hugging Face LeRobot Dataset.
|
formats them for use with a Hugging Face LeRobot Dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pipeline: The DataProcessorPipeline to apply.
|
pipeline: The DataProcessorPipeline to apply.
|
||||||
initial_features: A dictionary of raw feature specs for actions and observations.
|
initial_features: A dictionary of raw feature specs for actions and observations.
|
||||||
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
|
use_videos: If False, image features are excluded.
|
||||||
exclude_images: If True, image features are dropped entirely from the output.
|
|
||||||
patterns: A sequence of regex patterns to filter action and state features.
|
patterns: A sequence of regex patterns to filter action and state features.
|
||||||
Image features are not affected by this filter.
|
Image features are not affected by this filter.
|
||||||
|
|
||||||
@@ -122,7 +120,7 @@ def aggregate_pipeline_dataset_features(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 2. Apply filtering rules.
|
# 2. Apply filtering rules.
|
||||||
if is_image and exclude_images:
|
if is_image and not use_videos:
|
||||||
continue
|
continue
|
||||||
if not is_image and not should_keep(key, compiled_patterns):
|
if not is_image and not should_keep(key, compiled_patterns):
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
|
|||||||
|
|
||||||
import packaging
|
import packaging
|
||||||
import safetensors
|
import safetensors
|
||||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
|
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict
|
||||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||||
from huggingface_hub.errors import HfHubHTTPError
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||||
@@ -129,10 +129,43 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
if not getattr(cls, "name", None):
|
if not getattr(cls, "name", None):
|
||||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path) -> None:
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory: str | Path,
|
||||||
|
*,
|
||||||
|
state_dict: dict[str, Tensor] | None = None,
|
||||||
|
repo_id: str | None = None,
|
||||||
|
push_to_hub: bool = False,
|
||||||
|
card_kwargs: dict | None = None,
|
||||||
|
**push_to_hub_kwargs,
|
||||||
|
) -> str | None:
|
||||||
|
"""Save the policy to a directory (and optionally push to the Hub).
|
||||||
|
|
||||||
|
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
|
||||||
|
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
|
||||||
|
return sharded tensors, so the caller gathers the full state dict via a cross-rank
|
||||||
|
collective and passes it here for `_save_pretrained` to write directly.
|
||||||
|
"""
|
||||||
|
save_directory = Path(save_directory)
|
||||||
|
save_directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._save_pretrained(save_directory, state_dict=state_dict)
|
||||||
|
if push_to_hub:
|
||||||
|
if repo_id is None:
|
||||||
|
repo_id = save_directory.name
|
||||||
|
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
|
||||||
self.config._save_pretrained(save_directory)
|
self.config._save_pretrained(save_directory)
|
||||||
model_to_save = self.module if hasattr(self, "module") else self
|
model_to_save = self.module if hasattr(self, "module") else self
|
||||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
if state_dict is None:
|
||||||
|
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||||
|
return
|
||||||
|
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
|
||||||
|
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
|
||||||
|
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
|
||||||
|
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
|
||||||
|
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
@@ -270,6 +303,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
self,
|
self,
|
||||||
cfg: TrainPipelineConfig,
|
cfg: TrainPipelineConfig,
|
||||||
peft_model=None,
|
peft_model=None,
|
||||||
|
state_dict: dict[str, Tensor] | None = None,
|
||||||
):
|
):
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
repo_id = api.create_repo(
|
repo_id = api.create_repo(
|
||||||
@@ -287,7 +321,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
peft_model.save_pretrained(saved_path)
|
peft_model.save_pretrained(saved_path)
|
||||||
self.config.save_pretrained(saved_path)
|
self.config.save_pretrained(saved_path)
|
||||||
else:
|
else:
|
||||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
# Calls _save_pretrained and stores model tensors
|
||||||
|
self.save_pretrained(saved_path, state_dict=state_dict)
|
||||||
|
|
||||||
card = self.generate_model_card(
|
card = self.generate_model_card(
|
||||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.bimanual import BimanualMixin
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
from lerobot.utils.decorators import check_if_not_connected
|
|
||||||
|
|
||||||
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||||
from ..robot import Robot
|
from ..robot import Robot
|
||||||
@@ -28,7 +27,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiOpenArmFollower(BimanualMixin, Robot):
|
class BiOpenArmFollower(Robot):
|
||||||
"""
|
"""
|
||||||
Bimanual OpenArm Follower Arms
|
Bimanual OpenArm Follower Arms
|
||||||
"""
|
"""
|
||||||
@@ -40,17 +39,15 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
|
||||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
# will only open the cameras assigned to it. Per-arm cameras are used
|
||||||
self._top_level_cam_keys = set(config.cameras)
|
# as fallback when top-level cameras are empty.
|
||||||
_collisions = self._top_level_cam_keys & set(
|
if config.cameras:
|
||||||
config.left_arm_config.cameras
|
left_cameras = config.cameras
|
||||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
right_cameras = {}
|
||||||
if _collisions:
|
else:
|
||||||
raise ValueError(
|
left_cameras = config.left_arm_config.cameras
|
||||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
right_cameras = config.right_arm_config.cameras
|
||||||
)
|
|
||||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
|
||||||
|
|
||||||
left_arm_config = OpenArmFollowerConfig(
|
left_arm_config = OpenArmFollowerConfig(
|
||||||
id=f"{config.id}_left" if config.id else None,
|
id=f"{config.id}_left" if config.id else None,
|
||||||
@@ -59,7 +56,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
|||||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||||
max_relative_target=config.left_arm_config.max_relative_target,
|
max_relative_target=config.left_arm_config.max_relative_target,
|
||||||
cameras=left_arm_cameras,
|
cameras=left_cameras,
|
||||||
side=config.left_arm_config.side,
|
side=config.left_arm_config.side,
|
||||||
can_interface=config.left_arm_config.can_interface,
|
can_interface=config.left_arm_config.can_interface,
|
||||||
use_can_fd=config.left_arm_config.use_can_fd,
|
use_can_fd=config.left_arm_config.use_can_fd,
|
||||||
@@ -78,7 +75,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
|||||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||||
max_relative_target=config.right_arm_config.max_relative_target,
|
max_relative_target=config.right_arm_config.max_relative_target,
|
||||||
cameras=config.right_arm_config.cameras,
|
cameras=right_cameras,
|
||||||
side=config.right_arm_config.side,
|
side=config.right_arm_config.side,
|
||||||
can_interface=config.right_arm_config.can_interface,
|
can_interface=config.right_arm_config.can_interface,
|
||||||
use_can_fd=config.right_arm_config.use_can_fd,
|
use_can_fd=config.right_arm_config.use_can_fd,
|
||||||
@@ -98,19 +95,22 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _motors_ft(self) -> dict[str, type]:
|
def _motors_ft(self) -> dict[str, type]:
|
||||||
|
left_arm_motors_ft = self.left_arm._motors_ft
|
||||||
|
right_arm_motors_ft = self.right_arm._motors_ft
|
||||||
|
|
||||||
|
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
||||||
|
# and the dataset feature names recorded during data collection.
|
||||||
return {
|
return {
|
||||||
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
|
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
|
||||||
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
|
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _cameras_ft(self) -> dict[str, tuple]:
|
def _cameras_ft(self) -> dict[str, tuple]:
|
||||||
out: dict[str, tuple] = {}
|
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
|
||||||
for k, v in self.left_arm._cameras_ft.items():
|
# "right_wrist"), so we merge them directly — unlike motors which need the
|
||||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
# left_/right_ prefix to disambiguate identical per-arm joint names.
|
||||||
for k, v in self.right_arm._cameras_ft.items():
|
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
|
||||||
out[f"right_{k}"] = v
|
|
||||||
return out
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
@@ -120,6 +120,27 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
|||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return self._motors_ft
|
return self._motors_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||||
|
|
||||||
|
@check_if_already_connected
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
self.left_arm.connect(calibrate)
|
||||||
|
self.right_arm.connect(calibrate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
self.left_arm.calibrate()
|
||||||
|
self.right_arm.calibrate()
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
self.left_arm.configure()
|
||||||
|
self.right_arm.configure()
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||||
@@ -127,15 +148,21 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
|||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_observation(self) -> RobotObservation:
|
def get_observation(self) -> RobotObservation:
|
||||||
obs_dict: RobotObservation = {}
|
obs_dict = {}
|
||||||
|
|
||||||
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
# Camera keys that should NOT get the arm prefix (they already have unique names)
|
||||||
for key, value in self.left_arm.get_observation().items():
|
left_cam_keys = set(self.left_arm.cameras.keys())
|
||||||
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
right_cam_keys = set(self.right_arm.cameras.keys())
|
||||||
|
|
||||||
# Add "right_" prefix
|
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
||||||
for key, value in self.right_arm.get_observation().items():
|
# and the dataset feature names recorded during data collection.
|
||||||
obs_dict[f"right_{key}"] = value
|
right_obs = self.right_arm.get_observation()
|
||||||
|
for key, value in right_obs.items():
|
||||||
|
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
|
||||||
|
|
||||||
|
left_obs = self.left_arm.get_observation()
|
||||||
|
for key, value in left_obs.items():
|
||||||
|
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
@@ -162,4 +189,9 @@ class BiOpenArmFollower(BimanualMixin, Robot):
|
|||||||
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
||||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||||
|
|
||||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
|
||||||
|
|
||||||
|
@check_if_not_connected
|
||||||
|
def disconnect(self):
|
||||||
|
self.left_arm.disconnect()
|
||||||
|
self.right_arm.disconnect()
|
||||||
|
|||||||
@@ -32,7 +32,5 @@ class BiOpenArmFollowerConfig(RobotConfig):
|
|||||||
left_arm_config: OpenArmFollowerConfigBase
|
left_arm_config: OpenArmFollowerConfigBase
|
||||||
right_arm_config: OpenArmFollowerConfigBase
|
right_arm_config: OpenArmFollowerConfigBase
|
||||||
|
|
||||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
# Top-level cameras shared across both arms.
|
||||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
|
||||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
|
||||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.bimanual import BimanualMixin
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
from lerobot.utils.decorators import check_if_not_connected
|
|
||||||
|
|
||||||
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
|
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
|
||||||
from ..robot import Robot
|
from ..robot import Robot
|
||||||
@@ -28,7 +27,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiRebotB601Follower(BimanualMixin, Robot):
|
class BiRebotB601Follower(Robot):
|
||||||
"""Bimanual Seeed Studio reBot B601-DM follower.
|
"""Bimanual Seeed Studio reBot B601-DM follower.
|
||||||
|
|
||||||
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
|
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
|
||||||
@@ -42,18 +41,6 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
|
||||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
|
||||||
self._top_level_cam_keys = set(config.cameras)
|
|
||||||
_collisions = self._top_level_cam_keys & set(
|
|
||||||
config.left_arm_config.cameras
|
|
||||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
|
||||||
if _collisions:
|
|
||||||
raise ValueError(
|
|
||||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
|
||||||
)
|
|
||||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
|
||||||
|
|
||||||
left_arm_config = RebotB601FollowerRobotConfig(
|
left_arm_config = RebotB601FollowerRobotConfig(
|
||||||
id=f"{config.id}_left" if config.id else None,
|
id=f"{config.id}_left" if config.id else None,
|
||||||
calibration_dir=config.calibration_dir,
|
calibration_dir=config.calibration_dir,
|
||||||
@@ -62,7 +49,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
|||||||
dm_serial_baud=config.left_arm_config.dm_serial_baud,
|
dm_serial_baud=config.left_arm_config.dm_serial_baud,
|
||||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||||
max_relative_target=config.left_arm_config.max_relative_target,
|
max_relative_target=config.left_arm_config.max_relative_target,
|
||||||
cameras=left_arm_cameras,
|
cameras=config.left_arm_config.cameras,
|
||||||
motor_can_ids=config.left_arm_config.motor_can_ids,
|
motor_can_ids=config.left_arm_config.motor_can_ids,
|
||||||
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
||||||
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
||||||
@@ -99,12 +86,10 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _cameras_ft(self) -> dict[str, tuple]:
|
def _cameras_ft(self) -> dict[str, tuple]:
|
||||||
out: dict[str, tuple] = {}
|
return {
|
||||||
for k, v in self.left_arm._cameras_ft.items():
|
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
|
||||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
|
||||||
for k, v in self.right_arm._cameras_ft.items():
|
}
|
||||||
out[f"right_{k}"] = v
|
|
||||||
return out
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
@@ -114,13 +99,32 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
|||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return self._motors_ft
|
return self._motors_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||||
|
|
||||||
|
@check_if_already_connected
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
self.left_arm.connect(calibrate)
|
||||||
|
self.right_arm.connect(calibrate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
self.left_arm.calibrate()
|
||||||
|
self.right_arm.calibrate()
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
self.left_arm.configure()
|
||||||
|
self.right_arm.configure()
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_observation(self) -> RobotObservation:
|
def get_observation(self) -> RobotObservation:
|
||||||
obs_dict: RobotObservation = {}
|
obs_dict = {}
|
||||||
for k, v in self.left_arm.get_observation().items():
|
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
|
||||||
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
|
||||||
for k, v in self.right_arm.get_observation().items():
|
|
||||||
obs_dict[f"right_{k}"] = v
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
@@ -139,3 +143,8 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
|||||||
**{f"left_{k}": v for k, v in sent_action_left.items()},
|
**{f"left_{k}": v for k, v in sent_action_left.items()},
|
||||||
**{f"right_{k}": v for k, v in sent_action_right.items()},
|
**{f"right_{k}": v for k, v in sent_action_right.items()},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@check_if_not_connected
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
self.left_arm.disconnect()
|
||||||
|
self.right_arm.disconnect()
|
||||||
|
|||||||
@@ -14,9 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from lerobot.cameras import CameraConfig
|
|
||||||
|
|
||||||
from ..config import RobotConfig
|
from ..config import RobotConfig
|
||||||
from ..rebot_b601_follower import RebotB601FollowerConfig
|
from ..rebot_b601_follower import RebotB601FollowerConfig
|
||||||
@@ -29,8 +27,3 @@ class BiRebotB601FollowerConfig(RobotConfig):
|
|||||||
|
|
||||||
left_arm_config: RebotB601FollowerConfig
|
left_arm_config: RebotB601FollowerConfig
|
||||||
right_arm_config: RebotB601FollowerConfig
|
right_arm_config: RebotB601FollowerConfig
|
||||||
|
|
||||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
|
||||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
|
||||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
|
||||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.bimanual import BimanualMixin
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
from lerobot.utils.decorators import check_if_not_connected
|
|
||||||
|
|
||||||
from ..robot import Robot
|
from ..robot import Robot
|
||||||
from ..so_follower import SOFollower, SOFollowerRobotConfig
|
from ..so_follower import SOFollower, SOFollowerRobotConfig
|
||||||
@@ -28,7 +27,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiSOFollower(BimanualMixin, Robot):
|
class BiSOFollower(Robot):
|
||||||
"""
|
"""
|
||||||
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||||
"""
|
"""
|
||||||
@@ -40,18 +39,6 @@ class BiSOFollower(BimanualMixin, Robot):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
|
||||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
|
||||||
self._top_level_cam_keys = set(config.cameras)
|
|
||||||
_collisions = self._top_level_cam_keys & set(
|
|
||||||
config.left_arm_config.cameras
|
|
||||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
|
||||||
if _collisions:
|
|
||||||
raise ValueError(
|
|
||||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
|
||||||
)
|
|
||||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
|
||||||
|
|
||||||
left_arm_config = SOFollowerRobotConfig(
|
left_arm_config = SOFollowerRobotConfig(
|
||||||
id=f"{config.id}_left" if config.id else None,
|
id=f"{config.id}_left" if config.id else None,
|
||||||
calibration_dir=config.calibration_dir,
|
calibration_dir=config.calibration_dir,
|
||||||
@@ -59,7 +46,7 @@ class BiSOFollower(BimanualMixin, Robot):
|
|||||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||||
max_relative_target=config.left_arm_config.max_relative_target,
|
max_relative_target=config.left_arm_config.max_relative_target,
|
||||||
use_degrees=config.left_arm_config.use_degrees,
|
use_degrees=config.left_arm_config.use_degrees,
|
||||||
cameras=left_arm_cameras,
|
cameras=config.left_arm_config.cameras,
|
||||||
)
|
)
|
||||||
|
|
||||||
right_arm_config = SOFollowerRobotConfig(
|
right_arm_config = SOFollowerRobotConfig(
|
||||||
@@ -90,12 +77,13 @@ class BiSOFollower(BimanualMixin, Robot):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _cameras_ft(self) -> dict[str, tuple]:
|
def _cameras_ft(self) -> dict[str, tuple]:
|
||||||
out: dict[str, tuple] = {}
|
left_arm_cameras_ft = self.left_arm._cameras_ft
|
||||||
for k, v in self.left_arm._cameras_ft.items():
|
right_arm_cameras_ft = self.right_arm._cameras_ft
|
||||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
|
||||||
for k, v in self.right_arm._cameras_ft.items():
|
return {
|
||||||
out[f"right_{k}"] = v
|
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
||||||
return out
|
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
||||||
|
}
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
@@ -105,21 +93,42 @@ class BiSOFollower(BimanualMixin, Robot):
|
|||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return self._motors_ft
|
return self._motors_ft
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||||
|
|
||||||
|
@check_if_already_connected
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
self.left_arm.connect(calibrate)
|
||||||
|
self.right_arm.connect(calibrate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
self.left_arm.calibrate()
|
||||||
|
self.right_arm.calibrate()
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
self.left_arm.configure()
|
||||||
|
self.right_arm.configure()
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
self.left_arm.setup_motors()
|
self.left_arm.setup_motors()
|
||||||
self.right_arm.setup_motors()
|
self.right_arm.setup_motors()
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_observation(self) -> RobotObservation:
|
def get_observation(self) -> RobotObservation:
|
||||||
obs_dict: RobotObservation = {}
|
obs_dict = {}
|
||||||
|
|
||||||
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
# Add "left_" prefix
|
||||||
for key, value in self.left_arm.get_observation().items():
|
left_obs = self.left_arm.get_observation()
|
||||||
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||||
|
|
||||||
# Add "right_" prefix
|
# Add "right_" prefix
|
||||||
for key, value in self.right_arm.get_observation().items():
|
right_obs = self.right_arm.get_observation()
|
||||||
obs_dict[f"right_{key}"] = value
|
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
@@ -142,3 +151,8 @@ class BiSOFollower(BimanualMixin, Robot):
|
|||||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||||
|
|
||||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||||
|
|
||||||
|
@check_if_not_connected
|
||||||
|
def disconnect(self):
|
||||||
|
self.left_arm.disconnect()
|
||||||
|
self.right_arm.disconnect()
|
||||||
|
|||||||
@@ -14,9 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from lerobot.cameras import CameraConfig
|
|
||||||
|
|
||||||
from ..config import RobotConfig
|
from ..config import RobotConfig
|
||||||
from ..so_follower import SOFollowerConfig
|
from ..so_follower import SOFollowerConfig
|
||||||
@@ -29,8 +27,3 @@ class BiSOFollowerConfig(RobotConfig):
|
|||||||
|
|
||||||
left_arm_config: SOFollowerConfig
|
left_arm_config: SOFollowerConfig
|
||||||
right_arm_config: SOFollowerConfig
|
right_arm_config: SOFollowerConfig
|
||||||
|
|
||||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
|
||||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
|
||||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
|
||||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
|
||||||
|
|||||||
@@ -54,7 +54,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
Teleoperator,
|
Teleoperator,
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
bi_openarm_mini,
|
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
homunculus,
|
homunculus,
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ from lerobot.robots import ( # noqa: F401
|
|||||||
from lerobot.teleoperators import ( # noqa: F401
|
from lerobot.teleoperators import ( # noqa: F401
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
bi_openarm_mini,
|
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
gamepad,
|
gamepad,
|
||||||
|
|||||||
@@ -137,7 +137,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
Teleoperator,
|
Teleoperator,
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
bi_openarm_mini,
|
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
homunculus,
|
homunculus,
|
||||||
|
|||||||
@@ -174,7 +174,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
Teleoperator,
|
Teleoperator,
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
bi_openarm_mini,
|
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
homunculus,
|
homunculus,
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ from lerobot.robots import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
from lerobot.teleoperators import ( # noqa: F401
|
from lerobot.teleoperators import ( # noqa: F401
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_mini,
|
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
koch_leader,
|
koch_leader,
|
||||||
|
|||||||
@@ -89,7 +89,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
Teleoperator,
|
Teleoperator,
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
bi_openarm_mini,
|
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
gamepad,
|
gamepad,
|
||||||
|
|||||||
@@ -189,6 +189,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
|
|
||||||
require_package("accelerate", extra="training")
|
require_package("accelerate", extra="training")
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
|
||||||
|
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
|
|
||||||
@@ -197,8 +198,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||||
# We set find_unused_parameters=True to handle models with conditional computation
|
# We set find_unused_parameters=True to handle models with conditional computation
|
||||||
if accelerator is None:
|
if accelerator is None:
|
||||||
from accelerate.utils import DistributedDataParallelKwargs
|
|
||||||
|
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||||
@@ -558,20 +557,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
train_tracker.reset_averages()
|
train_tracker.reset_averages()
|
||||||
|
|
||||||
if cfg.save_checkpoint and is_saving_step:
|
if cfg.save_checkpoint and is_saving_step:
|
||||||
|
# All ranks must call get_state_dict; rank 0 gets the
|
||||||
|
# full state dict, others get an empty dict.
|
||||||
|
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
||||||
|
model_state_dict = accelerator.get_state_dict(policy)
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info(f"Checkpoint policy after step {step}")
|
logging.info(f"Checkpoint policy after step {step}")
|
||||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||||
|
if is_fsdp:
|
||||||
|
# TODO(fsdp): sharded optimizer-state save/resume is not wired up yet.
|
||||||
|
logging.warning(
|
||||||
|
"FSDP checkpoint: saving model weights only (optimizer state skipped; "
|
||||||
|
"resume-from-checkpoint not supported under FSDP yet)."
|
||||||
|
)
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
step=step,
|
step=step,
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
policy=accelerator.unwrap_model(policy),
|
policy=accelerator.unwrap_model(policy),
|
||||||
optimizer=optimizer,
|
optimizer=None if is_fsdp else optimizer,
|
||||||
scheduler=lr_scheduler,
|
scheduler=lr_scheduler,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
num_processes=accelerator.num_processes,
|
num_processes=accelerator.num_processes,
|
||||||
batch_size=cfg.batch_size,
|
batch_size=cfg.batch_size,
|
||||||
|
model_state_dict=model_state_dict,
|
||||||
)
|
)
|
||||||
update_last_checkpoint(checkpoint_dir)
|
update_last_checkpoint(checkpoint_dir)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
@@ -634,6 +644,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
if eval_env:
|
if eval_env:
|
||||||
close_envs(eval_env)
|
close_envs(eval_env)
|
||||||
|
|
||||||
|
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
||||||
|
model_state_dict = accelerator.get_state_dict(policy) if is_fsdp else None
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info("End of training")
|
logging.info("End of training")
|
||||||
|
|
||||||
@@ -643,7 +655,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
||||||
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
||||||
else:
|
else:
|
||||||
unwrapped_model.push_model_to_hub(cfg)
|
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
|
||||||
preprocessor.push_to_hub(active_cfg.repo_id)
|
preprocessor.push_to_hub(active_cfg.repo_id)
|
||||||
postprocessor.push_to_hub(active_cfg.repo_id)
|
postprocessor.push_to_hub(active_cfg.repo_id)
|
||||||
|
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction
|
from lerobot.types import RobotAction
|
||||||
from lerobot.utils.bimanual import BimanualMixin
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
from lerobot.utils.decorators import check_if_not_connected
|
|
||||||
|
|
||||||
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
||||||
from ..teleoperator import Teleoperator
|
from ..teleoperator import Teleoperator
|
||||||
@@ -28,7 +27,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiOpenArmLeader(BimanualMixin, Teleoperator):
|
class BiOpenArmLeader(Teleoperator):
|
||||||
"""
|
"""
|
||||||
Bimanual OpenArm Leader Arms
|
Bimanual OpenArm Leader Arms
|
||||||
"""
|
"""
|
||||||
@@ -87,6 +86,27 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
|
|||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||||
|
|
||||||
|
@check_if_already_connected
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
self.left_arm.connect(calibrate)
|
||||||
|
self.right_arm.connect(calibrate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
self.left_arm.calibrate()
|
||||||
|
self.right_arm.calibrate()
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
self.left_arm.configure()
|
||||||
|
self.right_arm.configure()
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||||
@@ -109,3 +129,8 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
|
|||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
# TODO: Implement force feedback
|
# TODO: Implement force feedback
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@check_if_not_connected
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
self.left_arm.disconnect()
|
||||||
|
self.right_arm.disconnect()
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
|
|||||||
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
|
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
|
||||||
@dataclass
|
@dataclass
|
||||||
class BiOpenArmLeaderConfig(TeleoperatorConfig):
|
class BiOpenArmLeaderConfig(TeleoperatorConfig):
|
||||||
"""Configuration class for Bi OpenArm Leader teleoperators."""
|
"""Configuration class for Bi OpenArm Follower robots."""
|
||||||
|
|
||||||
left_arm_config: OpenArmLeaderConfigBase
|
left_arm_config: OpenArmLeaderConfigBase
|
||||||
right_arm_config: OpenArmLeaderConfigBase
|
right_arm_config: OpenArmLeaderConfigBase
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from .bi_openarm_mini import BiOpenArmMini
|
|
||||||
from .config_bi_openarm_mini import BiOpenArmMiniConfig
|
|
||||||
|
|
||||||
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import cached_property
|
|
||||||
|
|
||||||
from lerobot.types import RobotAction
|
|
||||||
from lerobot.utils.bimanual import BimanualMixin
|
|
||||||
from lerobot.utils.decorators import check_if_not_connected
|
|
||||||
|
|
||||||
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
|
|
||||||
from ..teleoperator import Teleoperator
|
|
||||||
from .config_bi_openarm_mini import BiOpenArmMiniConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BiOpenArmMini(BimanualMixin, Teleoperator):
|
|
||||||
"""Bimanual OpenArm Mini teleoperator.
|
|
||||||
|
|
||||||
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
|
|
||||||
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
|
|
||||||
bimanual leader can teleoperate a bimanual OpenArm follower.
|
|
||||||
"""
|
|
||||||
|
|
||||||
config_class = BiOpenArmMiniConfig
|
|
||||||
name = "bi_openarm_mini"
|
|
||||||
|
|
||||||
def __init__(self, config: BiOpenArmMiniConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# `side` is forced to match left/right regardless of what the user passed
|
|
||||||
# on the per-arm base config — the bimanual wrapper owns the side semantics.
|
|
||||||
left_arm_config = OpenArmMiniConfig(
|
|
||||||
id=f"{config.id}_left" if config.id else None,
|
|
||||||
calibration_dir=config.calibration_dir,
|
|
||||||
port=config.left_arm_config.port,
|
|
||||||
side="left",
|
|
||||||
use_degrees=config.left_arm_config.use_degrees,
|
|
||||||
)
|
|
||||||
|
|
||||||
right_arm_config = OpenArmMiniConfig(
|
|
||||||
id=f"{config.id}_right" if config.id else None,
|
|
||||||
calibration_dir=config.calibration_dir,
|
|
||||||
port=config.right_arm_config.port,
|
|
||||||
side="right",
|
|
||||||
use_degrees=config.right_arm_config.use_degrees,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.left_arm = OpenArmMini(left_arm_config)
|
|
||||||
self.right_arm = OpenArmMini(right_arm_config)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def action_features(self) -> dict[str, type]:
|
|
||||||
return {
|
|
||||||
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
|
|
||||||
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
|
|
||||||
}
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def feedback_features(self) -> dict[str, type]:
|
|
||||||
return {
|
|
||||||
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
|
|
||||||
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
|
|
||||||
}
|
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
|
||||||
self.left_arm.setup_motors()
|
|
||||||
self.right_arm.setup_motors()
|
|
||||||
|
|
||||||
@check_if_not_connected
|
|
||||||
def get_action(self) -> RobotAction:
|
|
||||||
action: RobotAction = {}
|
|
||||||
for k, v in self.left_arm.get_action().items():
|
|
||||||
action[f"left_{k}"] = v
|
|
||||||
for k, v in self.right_arm.get_action().items():
|
|
||||||
action[f"right_{k}"] = v
|
|
||||||
return action
|
|
||||||
|
|
||||||
@check_if_not_connected
|
|
||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
|
||||||
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
|
|
||||||
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
|
|
||||||
if left_fb:
|
|
||||||
self.left_arm.send_feedback(left_fb)
|
|
||||||
if right_fb:
|
|
||||||
self.right_arm.send_feedback(right_fb)
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from ..config import TeleoperatorConfig
|
|
||||||
from ..openarm_mini import OpenArmMiniConfigBase
|
|
||||||
|
|
||||||
|
|
||||||
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
|
|
||||||
@dataclass
|
|
||||||
class BiOpenArmMiniConfig(TeleoperatorConfig):
|
|
||||||
"""Configuration class for Bi OpenArm Mini teleoperators."""
|
|
||||||
|
|
||||||
left_arm_config: OpenArmMiniConfigBase
|
|
||||||
right_arm_config: OpenArmMiniConfigBase
|
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .bi_rebot_102_leader import BiRebot102Leader
|
from .bi_rebot_102_leader import BiRebotArm102Leader
|
||||||
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
|
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
|
||||||
|
|
||||||
__all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
|
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]
|
||||||
|
|||||||
@@ -18,17 +18,16 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction
|
from lerobot.types import RobotAction
|
||||||
from lerobot.utils.bimanual import BimanualMixin
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
from lerobot.utils.decorators import check_if_not_connected
|
|
||||||
|
|
||||||
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
|
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
|
||||||
from ..teleoperator import Teleoperator
|
from ..teleoperator import Teleoperator
|
||||||
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
|
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiRebot102Leader(BimanualMixin, Teleoperator):
|
class BiRebotArm102Leader(Teleoperator):
|
||||||
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
|
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
|
||||||
|
|
||||||
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
|
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
|
||||||
@@ -36,10 +35,10 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
|
|||||||
leader can teleoperate a bimanual reBot B601 follower.
|
leader can teleoperate a bimanual reBot B601 follower.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = BiRebot102LeaderConfig
|
config_class = BiRebotArm102LeaderConfig
|
||||||
name = "bi_rebot_102_leader"
|
name = "bi_rebot_102_leader"
|
||||||
|
|
||||||
def __init__(self, config: BiRebot102LeaderConfig):
|
def __init__(self, config: BiRebotArm102LeaderConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@@ -77,6 +76,27 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
|
|||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||||
|
|
||||||
|
@check_if_already_connected
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
self.left_arm.connect(calibrate)
|
||||||
|
self.right_arm.connect(calibrate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
self.left_arm.calibrate()
|
||||||
|
self.right_arm.calibrate()
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
self.left_arm.configure()
|
||||||
|
self.right_arm.configure()
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_action(self) -> RobotAction:
|
def get_action(self) -> RobotAction:
|
||||||
action_dict = {}
|
action_dict = {}
|
||||||
@@ -86,3 +106,8 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
|
|||||||
|
|
||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
|
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
|
||||||
|
|
||||||
|
@check_if_not_connected
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
self.left_arm.disconnect()
|
||||||
|
self.right_arm.disconnect()
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
|
|||||||
|
|
||||||
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
|
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
|
||||||
@dataclass
|
@dataclass
|
||||||
class BiRebot102LeaderConfig(TeleoperatorConfig):
|
class BiRebotArm102LeaderConfig(TeleoperatorConfig):
|
||||||
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
|
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
|
||||||
|
|
||||||
left_arm_config: RebotArm102LeaderConfig
|
left_arm_config: RebotArm102LeaderConfig
|
||||||
|
|||||||
@@ -17,9 +17,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
from lerobot.utils.bimanual import BimanualMixin
|
|
||||||
from lerobot.utils.decorators import check_if_not_connected
|
|
||||||
|
|
||||||
from ..so_leader import SOLeader, SOLeaderTeleopConfig
|
from ..so_leader import SOLeader, SOLeaderTeleopConfig
|
||||||
from ..teleoperator import Teleoperator
|
from ..teleoperator import Teleoperator
|
||||||
@@ -28,7 +26,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiSOLeader(BimanualMixin, Teleoperator):
|
class BiSOLeader(Teleoperator):
|
||||||
"""
|
"""
|
||||||
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||||
"""
|
"""
|
||||||
@@ -69,12 +67,33 @@ class BiSOLeader(BimanualMixin, Teleoperator):
|
|||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||||
|
|
||||||
|
@check_if_already_connected
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
self.left_arm.connect(calibrate)
|
||||||
|
self.right_arm.connect(calibrate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
self.left_arm.calibrate()
|
||||||
|
self.right_arm.calibrate()
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
self.left_arm.configure()
|
||||||
|
self.right_arm.configure()
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
self.left_arm.setup_motors()
|
self.left_arm.setup_motors()
|
||||||
self.right_arm.setup_motors()
|
self.right_arm.setup_motors()
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_action(self) -> RobotAction:
|
def get_action(self) -> dict[str, float]:
|
||||||
action_dict = {}
|
action_dict = {}
|
||||||
|
|
||||||
# Add "left_" prefix
|
# Add "left_" prefix
|
||||||
@@ -90,3 +109,8 @@ class BiSOLeader(BimanualMixin, Teleoperator):
|
|||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
# TODO: Implement force feedback
|
# TODO: Implement force feedback
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@check_if_not_connected
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
self.left_arm.disconnect()
|
||||||
|
self.right_arm.disconnect()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
|
from .config_openarm_mini import OpenArmMiniConfig
|
||||||
from .openarm_mini import OpenArmMini
|
from .openarm_mini import OpenArmMini
|
||||||
|
|
||||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
|
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||||
|
|||||||
@@ -19,21 +19,12 @@ from dataclasses import dataclass
|
|||||||
from ..config import TeleoperatorConfig
|
from ..config import TeleoperatorConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OpenArmMiniConfigBase:
|
|
||||||
"""Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
|
|
||||||
|
|
||||||
# Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
|
|
||||||
port: str
|
|
||||||
|
|
||||||
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
|
|
||||||
# during readout. If `None`, no flipping is applied.
|
|
||||||
side: str | None = None
|
|
||||||
|
|
||||||
use_degrees: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||||
@dataclass
|
@dataclass
|
||||||
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
|
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||||
pass
|
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||||
|
|
||||||
|
port_right: str = "/dev/ttyUSB0"
|
||||||
|
port_left: str = "/dev/ttyUSB1"
|
||||||
|
|
||||||
|
use_degrees: bool = True
|
||||||
|
|||||||
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Per-side motor direction flips applied during readout.
|
# Motors whose direction is inverted during readout
|
||||||
SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
|
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
|
||||||
"left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
|
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
||||||
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Leader joint 6 ↔ follower joint 7 (symmetric — its own inverse).
|
# Leader joint 6 maps to follower joint 7 and vice versa
|
||||||
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
|
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
|
||||||
|
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
|
||||||
|
|
||||||
GRIPPER_TELEOP_TO_DEGREES = -0.65
|
GRIPPER_TELEOP_TO_DEGREES = -0.65
|
||||||
|
|
||||||
|
|
||||||
class OpenArmMini(Teleoperator):
|
class OpenArmMini(Teleoperator):
|
||||||
"""OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
|
"""
|
||||||
|
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||||
|
|
||||||
For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
|
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = OpenArmMiniConfig
|
config_class = OpenArmMiniConfig
|
||||||
@@ -56,12 +56,9 @@ class OpenArmMini(Teleoperator):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
|
|
||||||
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
|
|
||||||
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
|
|
||||||
|
|
||||||
norm_mode_body = MotorNormMode.DEGREES
|
norm_mode_body = MotorNormMode.DEGREES
|
||||||
motors = {
|
|
||||||
|
motors_right = {
|
||||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||||
@@ -72,15 +69,46 @@ class OpenArmMini(Teleoperator):
|
|||||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.bus = FeetechMotorsBus(
|
motors_left = {
|
||||||
port=self.config.port,
|
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||||
motors=motors,
|
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||||
calibration=self.calibration,
|
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||||
|
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||||
|
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||||
|
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||||
|
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||||
|
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||||
|
}
|
||||||
|
|
||||||
|
cal_right = {
|
||||||
|
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||||
|
}
|
||||||
|
cal_left = {
|
||||||
|
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||||
|
}
|
||||||
|
|
||||||
|
self.bus_right = FeetechMotorsBus(
|
||||||
|
port=self.config.port_right,
|
||||||
|
motors=motors_right,
|
||||||
|
calibration=cal_right,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bus_left = FeetechMotorsBus(
|
||||||
|
port=self.config.port_left,
|
||||||
|
motors=motors_left,
|
||||||
|
calibration=cal_left,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return {f"{motor}.pos": float for motor in self.bus.motors}
|
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
|
||||||
|
# and the dataset feature names recorded during data collection.
|
||||||
|
features: dict[str, type] = {}
|
||||||
|
for motor in self.bus_right.motors:
|
||||||
|
features[f"right_{motor}.pos"] = float
|
||||||
|
for motor in self.bus_left.motors:
|
||||||
|
features[f"left_{motor}.pos"] = float
|
||||||
|
return features
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
@@ -88,12 +116,14 @@ class OpenArmMini(Teleoperator):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self.bus.is_connected
|
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||||
|
|
||||||
@check_if_already_connected
|
@check_if_already_connected
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
logger.info(f"Connecting arm on {self.config.port}...")
|
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||||
self.bus.connect()
|
self.bus_right.connect()
|
||||||
|
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||||
|
self.bus_left.connect()
|
||||||
|
|
||||||
if calibrate:
|
if calibrate:
|
||||||
self.calibrate()
|
self.calibrate()
|
||||||
@@ -103,14 +133,14 @@ class OpenArmMini(Teleoperator):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_calibrated(self) -> bool:
|
def is_calibrated(self) -> bool:
|
||||||
return self.bus.is_calibrated
|
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
def calibrate(self) -> None:
|
||||||
"""
|
"""
|
||||||
Run calibration procedure for a single OpenArm Mini arm.
|
Run calibration procedure for OpenArm Mini.
|
||||||
|
|
||||||
1. Disable torque
|
1. Disable torque
|
||||||
2. Ask user to position arm in hanging position with gripper closed
|
2. Ask user to position arms in hanging position with grippers closed
|
||||||
3. Set this as zero position via half-turn homing
|
3. Set this as zero position via half-turn homing
|
||||||
4. Interactive gripper calibration (open/close positions)
|
4. Interactive gripper calibration (open/close positions)
|
||||||
5. Save calibration
|
5. Save calibration
|
||||||
@@ -122,51 +152,70 @@ class OpenArmMini(Teleoperator):
|
|||||||
)
|
)
|
||||||
if user_input.strip().lower() != "c":
|
if user_input.strip().lower() != "c":
|
||||||
logger.info(f"Using existing calibration for {self.id}")
|
logger.info(f"Using existing calibration for {self.id}")
|
||||||
self.bus.write_calibration(self.calibration)
|
cal_right = {
|
||||||
|
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||||
|
}
|
||||||
|
cal_left = {
|
||||||
|
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||||
|
}
|
||||||
|
self.bus_right.write_calibration(cal_right)
|
||||||
|
self.bus_left.write_calibration(cal_left)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"\nRunning calibration for {self}")
|
logger.info(f"\nRunning calibration for {self}")
|
||||||
|
|
||||||
self.bus.disable_torque()
|
self._calibrate_arm("right", self.bus_right)
|
||||||
|
self._calibrate_arm("left", self.bus_left)
|
||||||
|
|
||||||
logger.info("Setting Phase to 12 for all motors...")
|
self._save_calibration()
|
||||||
for motor in self.bus.motors:
|
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||||
self.bus.write("Phase", motor, 12)
|
|
||||||
|
|
||||||
for motor in self.bus.motors:
|
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||||
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
"""Calibrate a single arm with Feetech motors."""
|
||||||
|
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||||
|
|
||||||
|
bus.disable_torque()
|
||||||
|
|
||||||
|
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||||
|
for motor in bus.motors:
|
||||||
|
bus.write("Phase", motor, 12)
|
||||||
|
|
||||||
|
for motor in bus.motors:
|
||||||
|
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||||
|
|
||||||
input(
|
input(
|
||||||
"\nCalibration: Zero Position\n"
|
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||||
"Position the arm in the following configuration:\n"
|
"Position the arm in the following configuration:\n"
|
||||||
" - Arm hanging straight down\n"
|
" - Arm hanging straight down\n"
|
||||||
" - Gripper closed\n"
|
" - Gripper closed\n"
|
||||||
"Press ENTER when ready..."
|
"Press ENTER when ready..."
|
||||||
)
|
)
|
||||||
|
|
||||||
homing_offsets = self.bus.set_half_turn_homings()
|
homing_offsets = bus.set_half_turn_homings()
|
||||||
logger.info("Arm zero position set.")
|
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||||
|
|
||||||
print("\nSetting motor ranges\n")
|
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||||
|
|
||||||
if self.calibration is None:
|
if self.calibration is None:
|
||||||
self.calibration = {}
|
self.calibration = {}
|
||||||
|
|
||||||
motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
|
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||||
max_res = motor_resolution - 1
|
max_res = motor_resolution - 1
|
||||||
|
|
||||||
for motor_name, motor in self.bus.motors.items():
|
for motor_name, motor in bus.motors.items():
|
||||||
|
prefixed_name = f"{arm_name}_{motor_name}"
|
||||||
|
|
||||||
if motor_name == "gripper":
|
if motor_name == "gripper":
|
||||||
input(
|
input(
|
||||||
"\nGripper Calibration\n"
|
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||||
"Step 1: CLOSE the gripper fully\n"
|
f"Step 1: CLOSE the gripper fully\n"
|
||||||
"Press ENTER when gripper is closed..."
|
f"Press ENTER when gripper is closed..."
|
||||||
)
|
)
|
||||||
closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
|
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||||
|
|
||||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||||
open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
|
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||||
|
|
||||||
if closed_pos < open_pos:
|
if closed_pos < open_pos:
|
||||||
@@ -179,16 +228,16 @@ class OpenArmMini(Teleoperator):
|
|||||||
drive_mode = 1
|
drive_mode = 1
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f" {motor_name}: range set to [{range_min}, {range_max}] "
|
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
range_min = 0
|
range_min = 0
|
||||||
range_max = max_res
|
range_max = max_res
|
||||||
drive_mode = 0
|
drive_mode = 0
|
||||||
logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
|
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||||
|
|
||||||
self.calibration[motor_name] = MotorCalibration(
|
self.calibration[prefixed_name] = MotorCalibration(
|
||||||
id=motor.id,
|
id=motor.id,
|
||||||
drive_mode=drive_mode,
|
drive_mode=drive_mode,
|
||||||
homing_offset=homing_offsets[motor_name],
|
homing_offset=homing_offsets[motor_name],
|
||||||
@@ -196,68 +245,108 @@ class OpenArmMini(Teleoperator):
|
|||||||
range_max=range_max,
|
range_max=range_max,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.bus.write_calibration(self.calibration)
|
cal_for_bus = {
|
||||||
self._save_calibration()
|
k.replace(f"{arm_name}_", ""): v
|
||||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
for k, v in self.calibration.items()
|
||||||
|
if k.startswith(f"{arm_name}_")
|
||||||
|
}
|
||||||
|
bus.write_calibration(cal_for_bus)
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
self.bus.disable_torque()
|
self.bus_right.disable_torque()
|
||||||
self.bus.configure_motors()
|
self.bus_right.configure_motors()
|
||||||
for motor in self.bus.motors:
|
for motor in self.bus_right.motors:
|
||||||
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||||
|
|
||||||
|
self.bus_left.disable_torque()
|
||||||
|
self.bus_left.configure_motors()
|
||||||
|
for motor in self.bus_left.motors:
|
||||||
|
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
for motor in reversed(self.bus.motors):
|
print("\nSetting up RIGHT arm motors...")
|
||||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
for motor in reversed(self.bus_right.motors):
|
||||||
self.bus.setup_motor(motor)
|
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
self.bus_right.setup_motor(motor)
|
||||||
|
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||||
|
|
||||||
|
print("\nSetting up LEFT arm motors...")
|
||||||
|
for motor in reversed(self.bus_left.motors):
|
||||||
|
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||||
|
self.bus_left.setup_motor(motor)
|
||||||
|
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_action(self) -> RobotAction:
|
def get_action(self) -> RobotAction:
|
||||||
"""Get current action (read positions from all motors)."""
|
"""Get current action from both arms (read positions from all motors)."""
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
positions = self.bus.sync_read("Present_Position")
|
right_positions = self.bus_right.sync_read("Present_Position")
|
||||||
|
left_positions = self.bus_left.sync_read("Present_Position")
|
||||||
|
|
||||||
|
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
|
||||||
|
# and the dataset feature names recorded during data collection.
|
||||||
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
|
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
|
||||||
# Per-side direction flip is applied based on the configured `side`.
|
|
||||||
action: dict[str, Any] = {}
|
action: dict[str, Any] = {}
|
||||||
for motor, val in positions.items():
|
for motor, val in right_positions.items():
|
||||||
target = JOINT_REMAP.get(motor, motor)
|
target = JOINT_REMAP.get(motor, motor)
|
||||||
if motor == "gripper":
|
if motor == "gripper":
|
||||||
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
|
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
|
||||||
action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
||||||
else:
|
else:
|
||||||
action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
|
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||||
|
for motor, val in left_positions.items():
|
||||||
|
target = JOINT_REMAP.get(motor, motor)
|
||||||
|
if motor == "gripper":
|
||||||
|
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
||||||
|
else:
|
||||||
|
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||||
|
|
||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def enable_torque(self) -> None:
|
def enable_torque(self) -> None:
|
||||||
self.bus.enable_torque()
|
"""Enable torque on both arms for position control."""
|
||||||
|
self.bus_right.enable_torque()
|
||||||
|
self.bus_left.enable_torque()
|
||||||
|
|
||||||
def disable_torque(self) -> None:
|
def disable_torque(self) -> None:
|
||||||
self.bus.disable_torque()
|
"""Disable torque on both arms for free movement."""
|
||||||
|
self.bus_right.disable_torque()
|
||||||
|
self.bus_left.disable_torque()
|
||||||
|
|
||||||
def write_goal_positions(self, positions: dict[str, float]) -> None:
|
def write_goal_positions(self, positions: dict[str, float]) -> None:
|
||||||
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
|
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
|
||||||
goals: dict[str, float] = {}
|
right_goals: dict[str, float] = {}
|
||||||
|
left_goals: dict[str, float] = {}
|
||||||
|
|
||||||
for key, val in positions.items():
|
for key, val in positions.items():
|
||||||
if not key.endswith(".pos"):
|
if not key.endswith(".pos"):
|
||||||
continue
|
continue
|
||||||
base = key.removesuffix(".pos")
|
motor_name = key.removesuffix(".pos")
|
||||||
# JOINT_REMAP is symmetric (its own inverse).
|
if motor_name.startswith("right_"):
|
||||||
target = JOINT_REMAP.get(base, base)
|
base = motor_name.removeprefix("right_")
|
||||||
if base == "gripper":
|
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
|
||||||
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
|
target = JOINT_REMAP_REVERSE.get(base, base)
|
||||||
goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
if base == "gripper":
|
||||||
else:
|
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
|
||||||
# Un-flip using the ORIGINAL motor name (target = leader motor)
|
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
||||||
goals[target] = -val if target in self._motors_to_flip else val
|
else:
|
||||||
|
# Un-flip using the ORIGINAL motor name (target = leader motor)
|
||||||
|
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
|
||||||
|
elif motor_name.startswith("left_"):
|
||||||
|
base = motor_name.removeprefix("left_")
|
||||||
|
target = JOINT_REMAP_REVERSE.get(base, base)
|
||||||
|
if base == "gripper":
|
||||||
|
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
||||||
|
else:
|
||||||
|
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
|
||||||
|
|
||||||
if goals:
|
if right_goals:
|
||||||
self.bus.sync_write("Goal_Position", goals)
|
self.bus_right.sync_write("Goal_Position", right_goals)
|
||||||
|
if left_goals:
|
||||||
|
self.bus_left.sync_write("Goal_Position", left_goals)
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
@@ -265,5 +354,6 @@ class OpenArmMini(Teleoperator):
|
|||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
self.bus.disconnect()
|
self.bus_right.disconnect()
|
||||||
|
self.bus_left.disconnect()
|
||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|||||||
@@ -99,18 +99,14 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
|||||||
from .openarm_mini import OpenArmMini
|
from .openarm_mini import OpenArmMini
|
||||||
|
|
||||||
return OpenArmMini(config)
|
return OpenArmMini(config)
|
||||||
elif config.type == "bi_openarm_mini":
|
|
||||||
from .bi_openarm_mini import BiOpenArmMini
|
|
||||||
|
|
||||||
return BiOpenArmMini(config)
|
|
||||||
elif config.type == "rebot_102_leader":
|
elif config.type == "rebot_102_leader":
|
||||||
from .rebot_102_leader import RebotArm102Leader
|
from .rebot_102_leader import RebotArm102Leader
|
||||||
|
|
||||||
return RebotArm102Leader(config)
|
return RebotArm102Leader(config)
|
||||||
elif config.type == "bi_rebot_102_leader":
|
elif config.type == "bi_rebot_102_leader":
|
||||||
from .bi_rebot_102_leader import BiRebot102Leader
|
from .bi_rebot_102_leader import BiRebotArm102Leader
|
||||||
|
|
||||||
return BiRebot102Leader(config)
|
return BiRebotArm102Leader(config)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return cast("Teleoperator", make_device_from_device_class(config))
|
return cast("Teleoperator", make_device_from_device_class(config))
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
|
||||||
|
|
||||||
|
|
||||||
class BimanualMixin:
|
|
||||||
"""Lifecycle delegation for bimanual robots and teleoperators.
|
|
||||||
|
|
||||||
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
|
|
||||||
their own ``__init__``. They retain ownership of feature dicts and the
|
|
||||||
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
|
|
||||||
``send_feedback``), which vary per-embodiment.
|
|
||||||
|
|
||||||
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
|
|
||||||
take precedence in the MRO::
|
|
||||||
|
|
||||||
class BiFooFollower(BimanualMixin, Robot): ...
|
|
||||||
"""
|
|
||||||
|
|
||||||
left_arm: Any
|
|
||||||
right_arm: Any
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_calibrated(self) -> bool:
|
|
||||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
|
||||||
|
|
||||||
@check_if_already_connected
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
|
||||||
self.left_arm.connect(calibrate)
|
|
||||||
self.right_arm.connect(calibrate)
|
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
|
||||||
self.left_arm.calibrate()
|
|
||||||
self.right_arm.calibrate()
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
|
||||||
self.left_arm.configure()
|
|
||||||
self.right_arm.configure()
|
|
||||||
|
|
||||||
@check_if_not_connected
|
|
||||||
def disconnect(self) -> None:
|
|
||||||
self.left_arm.disconnect()
|
|
||||||
self.right_arm.disconnect()
|
|
||||||
@@ -28,7 +28,6 @@ import pytest
|
|||||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||||
|
|
||||||
import pandas as pd # noqa: E402
|
|
||||||
import pyarrow.parquet as pq # noqa: E402
|
import pyarrow.parquet as pq # noqa: E402
|
||||||
|
|
||||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
||||||
@@ -345,78 +344,6 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
|
|||||||
assert len(dataset) == 24
|
assert len(dataset) == 24
|
||||||
|
|
||||||
|
|
||||||
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
|
|
||||||
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
|
|
||||||
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
|
|
||||||
from lerobot.datasets.io_utils import write_tasks
|
|
||||||
from lerobot.utils.io_utils import write_json
|
|
||||||
|
|
||||||
data_dir = root / "data" / "chunk-000"
|
|
||||||
data_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
|
|
||||||
for ep, length in enumerate(episode_lengths):
|
|
||||||
episode_index += [ep] * length
|
|
||||||
frame_index += list(range(length))
|
|
||||||
timestamp += [round(i / fps, 6) for i in range(length)]
|
|
||||||
task_index += [0] * length
|
|
||||||
subtask_index += [0] * length # legacy column the writer must drop
|
|
||||||
pd.DataFrame(
|
|
||||||
{
|
|
||||||
"episode_index": episode_index,
|
|
||||||
"frame_index": frame_index,
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"task_index": task_index,
|
|
||||||
"subtask_index": subtask_index,
|
|
||||||
}
|
|
||||||
).to_parquet(data_dir / "file-000.parquet", index=False)
|
|
||||||
|
|
||||||
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
|
|
||||||
write_tasks(tasks_df, root)
|
|
||||||
write_json(
|
|
||||||
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
|
|
||||||
root / "meta" / "info.json",
|
|
||||||
)
|
|
||||||
return root
|
|
||||||
|
|
||||||
|
|
||||||
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
|
|
||||||
"""Rewriting a packed shard must keep one row group per episode, not collapse
|
|
||||||
every episode into a single giant row group."""
|
|
||||||
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
|
|
||||||
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
|
|
||||||
shard = root / "data" / "chunk-000" / "file-000.parquet"
|
|
||||||
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
|
|
||||||
|
|
||||||
staging_dir = tmp_path / "stage"
|
|
||||||
for ep in range(len(episode_lengths)):
|
|
||||||
_stage_episode(
|
|
||||||
staging_dir,
|
|
||||||
ep,
|
|
||||||
plan=[
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": f"subtask for ep {ep}",
|
|
||||||
"style": "subtask",
|
|
||||||
"timestamp": 0.0,
|
|
||||||
"tool_calls": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
records = list(iter_episodes(root))
|
|
||||||
LanguageColumnsWriter().write_all(records, staging_dir, root)
|
|
||||||
|
|
||||||
# One row group per episode, with row counts matching the episode lengths.
|
|
||||||
md = pq.ParquetFile(shard).metadata
|
|
||||||
assert md.num_row_groups == len(episode_lengths)
|
|
||||||
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
|
|
||||||
# Language columns are still present after the per-episode rewrite.
|
|
||||||
table = pq.read_table(shard)
|
|
||||||
assert "language_persistent" in table.column_names
|
|
||||||
assert "language_events" in table.column_names
|
|
||||||
|
|
||||||
|
|
||||||
def test_speech_atom_shape_matches_plan_spec() -> None:
|
def test_speech_atom_shape_matches_plan_spec() -> None:
|
||||||
atom = speech_atom(2.5, "I'm cleaning up!")
|
atom = speech_atom(2.5, "I'm cleaning up!")
|
||||||
assert atom["role"] == "assistant"
|
assert atom["role"] == "assistant"
|
||||||
|
|||||||
@@ -32,26 +32,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|||||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||||
|
|
||||||
|
|
||||||
def assert_data_shards_one_row_group_per_episode(root):
|
|
||||||
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
|
|
||||||
import pyarrow.parquet as pq
|
|
||||||
|
|
||||||
shards = sorted((root / "data").rglob("*.parquet"))
|
|
||||||
assert shards, f"no data shards found under {root}/data"
|
|
||||||
n_episodes = 0
|
|
||||||
for shard in shards:
|
|
||||||
pf = pq.ParquetFile(shard)
|
|
||||||
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
|
|
||||||
assert pf.metadata.num_row_groups == len(set(episodes)), shard
|
|
||||||
for i in range(pf.metadata.num_row_groups):
|
|
||||||
rg_episodes = set(
|
|
||||||
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
|
|
||||||
)
|
|
||||||
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
|
|
||||||
n_episodes += len(set(episodes))
|
|
||||||
return n_episodes
|
|
||||||
|
|
||||||
|
|
||||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
||||||
"""Test that total number of episodes and frames are correctly aggregated."""
|
"""Test that total number of episodes and frames are correctly aggregated."""
|
||||||
assert aggr_ds.num_episodes == expected_episodes, (
|
assert aggr_ds.num_episodes == expected_episodes, (
|
||||||
@@ -586,41 +566,6 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
|
|
||||||
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
|
|
||||||
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
|
|
||||||
|
|
||||||
Covers both the non-image (``df.to_parquet``) and image
|
|
||||||
(``to_parquet_with_hf_images``) write branches, including the merge-into-
|
|
||||||
existing-file branch via a low file-size threshold that forces packing.
|
|
||||||
"""
|
|
||||||
ds_0 = lerobot_dataset_factory(
|
|
||||||
root=tmp_path / "rg_0",
|
|
||||||
repo_id=f"{DUMMY_REPO_ID}_rg_0",
|
|
||||||
total_episodes=3,
|
|
||||||
total_frames=60,
|
|
||||||
use_videos=use_videos,
|
|
||||||
)
|
|
||||||
ds_1 = lerobot_dataset_factory(
|
|
||||||
root=tmp_path / "rg_1",
|
|
||||||
repo_id=f"{DUMMY_REPO_ID}_rg_1",
|
|
||||||
total_episodes=4,
|
|
||||||
total_frames=80,
|
|
||||||
use_videos=use_videos,
|
|
||||||
)
|
|
||||||
|
|
||||||
aggr_root = tmp_path / "rg_aggr"
|
|
||||||
aggregate_datasets(
|
|
||||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
|
||||||
roots=[ds_0.root, ds_1.root],
|
|
||||||
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
|
|
||||||
aggr_root=aggr_root,
|
|
||||||
)
|
|
||||||
|
|
||||||
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
|
|
||||||
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
|
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||||
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
|
|||||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||||
from tests.mocks.mock_robot import MockRobotConfig
|
from tests.mocks.mock_robot import MockRobotConfig
|
||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import require_x86_64_kernel
|
||||||
|
|
||||||
@@ -133,21 +133,6 @@ def test_dataset_feature_with_forward_slash_raises_error():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
|
|
||||||
# ``create`` must deep-copy features so a dataset built from another's features stays independent.
|
|
||||||
dataset = empty_lerobot_dataset_factory(
|
|
||||||
root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
|
|
||||||
)
|
|
||||||
dataset_copy = empty_lerobot_dataset_factory(
|
|
||||||
root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False
|
|
||||||
)
|
|
||||||
|
|
||||||
original_shape = dataset.meta.info.features["state"]["shape"]
|
|
||||||
dataset_copy.meta.info.features["state"]["shape"] = (999,)
|
|
||||||
|
|
||||||
assert dataset.meta.info.features["state"]["shape"] == original_shape
|
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import torch
|
|||||||
|
|
||||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
|
|
||||||
|
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
@@ -300,6 +301,29 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
|||||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
|
||||||
|
"""Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
|
||||||
|
policy_cls = get_policy_class("act")
|
||||||
|
policy_cfg = make_policy_config("act")
|
||||||
|
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
||||||
|
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||||
|
policy_cfg.input_features = {
|
||||||
|
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||||
|
}
|
||||||
|
policy = policy_cls(policy_cfg)
|
||||||
|
policy.to(policy_cfg.device)
|
||||||
|
|
||||||
|
save_dir = tmp_path / "fsdp_state_dict"
|
||||||
|
policy.save_pretrained(save_dir, state_dict=policy.state_dict())
|
||||||
|
|
||||||
|
# A single, unsharded safetensors file (no sharded set + index).
|
||||||
|
assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
|
||||||
|
assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
|
||||||
|
|
||||||
|
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||||
|
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("multikey", [True, False])
|
@pytest.mark.parametrize("multikey", [True, False])
|
||||||
def test_multikey_construction(multikey: bool):
|
def test_multikey_construction(multikey: bool):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2370,32 +2370,14 @@ def test_aggregate_images_when_use_videos_false():
|
|||||||
out = aggregate_pipeline_dataset_features(
|
out = aggregate_pipeline_dataset_features(
|
||||||
pipeline=rp,
|
pipeline=rp,
|
||||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||||
use_videos=False, # images kept, stored as "image" dtype
|
use_videos=False, # expect "image" dtype
|
||||||
patterns=None,
|
patterns=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
key = f"{OBS_IMAGES}.back"
|
key = f"{OBS_IMAGES}.back"
|
||||||
key_front = f"{OBS_IMAGES}.front"
|
key_front = f"{OBS_IMAGES}.front"
|
||||||
assert key in out
|
assert key not in out
|
||||||
assert key_front in out
|
assert key_front not in out
|
||||||
assert out[key]["dtype"] == "image"
|
|
||||||
assert out[key_front]["dtype"] == "image"
|
|
||||||
assert out[key]["shape"] == initial["back"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_images_excluded():
|
|
||||||
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
|
|
||||||
initial = {"back": (480, 640, 3)}
|
|
||||||
|
|
||||||
out = aggregate_pipeline_dataset_features(
|
|
||||||
pipeline=rp,
|
|
||||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
|
||||||
exclude_images=True,
|
|
||||||
patterns=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert f"{OBS_IMAGES}.back" not in out
|
|
||||||
assert f"{OBS_IMAGES}.front" not in out
|
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_images_when_use_videos_true():
|
def test_aggregate_images_when_use_videos_true():
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
|
from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
|
||||||
from lerobot.teleoperators.rebot_102_leader import (
|
from lerobot.teleoperators.rebot_102_leader import (
|
||||||
RebotArm102Leader,
|
RebotArm102Leader,
|
||||||
RebotArm102LeaderConfig,
|
RebotArm102LeaderConfig,
|
||||||
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
|
|||||||
|
|
||||||
def test_bimanual_prefixes_features():
|
def test_bimanual_prefixes_features():
|
||||||
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
|
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
|
||||||
cfg = BiRebot102LeaderConfig(
|
cfg = BiRebotArm102LeaderConfig(
|
||||||
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
|
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
|
||||||
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
|
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
|
||||||
)
|
)
|
||||||
teleop = BiRebot102Leader(cfg)
|
teleop = BiRebotArm102Leader(cfg)
|
||||||
assert any(k.startswith("left_") for k in teleop.action_features)
|
assert any(k.startswith("left_") for k in teleop.action_features)
|
||||||
assert any(k.startswith("right_") for k in teleop.action_features)
|
assert any(k.startswith("right_") for k in teleop.action_features)
|
||||||
assert "left_gripper.pos" in teleop.action_features
|
assert "left_gripper.pos" in teleop.action_features
|
||||||
|
|||||||
Reference in New Issue
Block a user