Compare commits

13 Commits

Author SHA1 Message Date
Maxime Ellerbach 73782447f2 feat(train): FSDP checkpoint saving (#3810)
* feat(train): FSDP checkpoint saving

* adding docs for FSDP

* adding a test for the fsdp checkpoint path

* cleanup

* fixing final upload to hub

* refactored initial implementation to use torch fsdp api and adding new tests
2026-06-22 13:51:21 +02:00
Khalil Meftah 2d7a42011a fix(policies): support offline batch inference for ACT and Diffusion (#3822)
- Guard ACT's KL divergence computation against None latent params to
prevent crashes during eval when use_vae is set but the forward path
returns no VAE outputs.
- Add offline batch fallback to Diffusion's predict_action_chunk() so
it works with dataloader batches (empty queues) in addition to the
existing online rollout path (populated queues). This enables batched
action prediction for offline evaluation.
2026-06-21 11:48:45 +02:00
Khalil Meftah b06ad40888 feat(hub): add pretrained_revision to pin Hub model versions (#3820)
- Add pretrained_revision field to PreTrainedConfig (policies) and
RewardModelConfig (reward models), and thread it through make_policy(),
make_pre_post_processors(), and make_reward_model() so that weights and
processor configs can be loaded from a specific Hub commit, branch, or
tag. Defaults to None (latest version, preserving current behavior).
Dataset and env hub loading already supported revision pinning.

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-06-19 18:32:47 +02:00
Khalil Meftah b3d74f80f0 Fix batch wandb logging metrics and handle scalar stats (#3821)
* fix(logging): batch wandb metrics

- Batch all metrics into a single wandb.log() call instead of one per
key, reducing API overhead.

- Add support for list-valued metrics by expanding them to indexed keys (e.g.
metric_0, metric_1).

* fix(stats): handle scalar stats robustly

- Wrap cast_stats_to_numpy with np.atleast_1d to prevent 0-d arrays
from scalar stats causing shape mismatches downstream.

* fix(logging): remove unused list-valued metric expansion
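
The batching change in the first item above (one `wandb.log()` call per step instead of one per key) can be sketched as follows. This is a minimal illustration under stated assumptions, not the repository's `WandBLogger`: `FakeRun`, `build_batch`, and `log_dict` are hypothetical names, and `FakeRun` only stands in for a wandb run so the number of calls can be observed.

```python
class FakeRun:
    """Hypothetical stand-in for a wandb run; records every log() call."""

    def __init__(self):
        self.calls = []

    def log(self, data, step=None):
        self.calls.append((dict(data), step))


def build_batch(metrics, mode, custom_step_key=None):
    """Collect all loggable metrics under '<mode>/<key>' in one dict."""
    batch = {}
    for key, value in metrics.items():
        # The custom step key is added once by the caller, not per metric.
        if custom_step_key is not None and key == custom_step_key:
            continue
        # Only plain scalars and strings are logged; other types are skipped.
        if not isinstance(value, (int, float, str)):
            continue
        batch[f"{mode}/{key}"] = value
    return batch


def log_dict(run, metrics, step, mode="train", custom_step_key=None):
    """Send every metric of one step in a single log() call."""
    batch = build_batch(metrics, mode, custom_step_key)
    if not batch:
        return
    if custom_step_key is not None:
        # With a custom step key, the step travels inside the payload.
        batch[f"{mode}/{custom_step_key}"] = metrics[custom_step_key]
        run.log(batch)
    else:
        run.log(batch, step=step)
```

Calling `log_dict(run, {"loss": 0.5, "lr": 1e-4}, step=10)` issues one `log()` call carrying both keys, where a per-key loop would issue two.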

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-06-19 18:31:12 +02:00
Khalil Meftah 552b4c3563 Add third-party env plugin discovery (#3823)
* feat(envs): add env plugin discovery

- Add 'lerobot_env_' to third-party plugin discovery prefixes, completing
the plugin system for all component types (robots, cameras, teleoperators,
policies, and now environments). External packages named lerobot_env_*
can self-register EnvConfig subclasses on import, enabling --env.type=
resolution without lerobot code changes.

* feat(envs): add generic observation passthrough

- Add generic observation passthrough in preprocess_observation() for
unhandled ndarray/tensor keys, replacing the pattern of adding per-env
hardcoded key handlers. Extra keys are forwarded as observation.<key>
and can be shaped by env-specific ProcessorSteps via get_env_processors().

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-06-19 18:30:00 +02:00
Nicolas Rabault 8bf6056d14 docs: add LeLab web interface to README (#3831) 2026-06-17 18:22:21 +02:00
Caroline Pascal da92db8fc0 fix(image transforms): cleaning up image_transforms implementation in LeRobotDataset (#3829) 2026-06-17 11:50:09 +02:00
Caroline Pascal 2b0834bcb8 fix(cameras): snapshot stop_event in read loops to avoid None deref (#3812)
* Do not set stop_event to None when stopping thread

* fix(cameras): snapshot stop_event in read loops to avoid None deref
The background read loops accessed self.stop_event repeatedly while
_stop_read_thread() can reassign it to None after join(). Reading the
attribute across the loop condition (and a mid-loop re-check) was a
time-of-check/time-of-use race: stop_event could flip to None between
the `is None` test and the `.is_set()` call, raising AttributeError on
the worker thread.
Snapshot self.stop_event into a local once, guard it, and loop on the
local Event. The Event object is thread-safe and lives for the thread's
lifetime; _stop_read_thread() always calls .set() before nulling the
attribute, so the local observes the stop and exits cleanly. This also
lets us drop the redundant pre-lock stop check.
Applies to OpenCVCamera, RealSenseCamera, and ZMQ camera.
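
The snapshot pattern described above can be sketched with the standard library alone. This is an illustrative reduction, not the camera classes themselves: `Reader`, `frames_read`, and the sleep-based loop body are hypothetical, and only the handling of `stop_event` mirrors the fix.

```python
import threading
import time


class Reader:
    """Hypothetical background reader illustrating the stop_event snapshot."""

    def __init__(self):
        self.stop_event = None
        self.thread = None
        self.frames_read = 0

    def start(self):
        # The event is created before the thread starts, so the worker
        # always finds it when it takes its snapshot.
        self.stop_event = threading.Event()
        self.thread = threading.Thread(target=self._read_loop, daemon=True)
        self.thread.start()

    def _read_loop(self):
        # Snapshot the attribute once. The local keeps the Event alive even
        # if stop() later sets self.stop_event to None.
        stop_event = self.stop_event
        if stop_event is None:
            raise RuntimeError("stop_event is not initialized before starting read loop.")
        while not stop_event.is_set():
            self.frames_read += 1  # placeholder for a hardware read
            time.sleep(0.001)

    def stop(self):
        # set() always happens before the attribute is cleared, so the
        # worker's local Event observes the stop and the loop exits.
        if self.stop_event is not None:
            self.stop_event.set()
        if self.thread is not None:
            self.thread.join(timeout=2.0)
        self.thread = None
        self.stop_event = None
```

Looping on `self.stop_event.is_set()` instead would re-read the attribute on every iteration and could hit `None` after `stop()` clears it.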

---------

Co-authored-by: Anes Benmerzoug <anes.benmerzoug@gmail.com>
2026-06-17 11:40:17 +02:00
Caroline Pascal 287c823f13 fix(features copy): adding deepcopy on LeRobot dataset features to avoid shallow copy leaks (#3826)
* fix(features copy): adding deepcopy on LeRobot dataset features to avoid shallow copy leaks

* tests(test): adding new test
2026-06-16 17:58:59 +02:00
Pepijn 58ccc01508 fix(datasets): enforce one parquet row group per episode in v3 data writes (#3807)
* fix(datasets): enforce one parquet row group per episode in v3 data writes

LeRobot v3 data shards must hold exactly one row group per episode so a
reader can fetch episode i with pq.ParquetFile(path).read_row_group(i)
(a byte-range read) instead of loading the whole shard. The recording
writer already does this (one write_table per episode); the aggregate
and lerobot-annotate re-write paths instead concatenated many episodes
and wrote them in one shot, collapsing the file to a single row group.

- io_utils: add write_table_one_row_group_per_episode (one ParquetWriter,
  one write_table per episode — same pattern as the recording writer);
  to_parquet_with_hf_images embeds images then writes per-episode row
  groups; to_parquet_one_row_group_per_episode wraps it for plain frames
- aggregate: route non-image data writes through the per-episode writer;
  leave the episodes-metadata parquet untouched (already one row/episode)
- annotate: rewrite shards via the per-episode writer instead of a single
  bulk pq.write_table
- tests: invariant coverage through the aggregate (image + video) and
  annotate paths

No change to on-disk schema, paths, naming, rollover thresholds, or
compression. Readers stay backward-compatible (old collapsed files load).

* Update src/lerobot/datasets/io_utils.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update src/lerobot/datasets/io_utils.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(datasets): correct indentation and add strict= in row-group helper

The web-edited numpy version of write_table_one_row_group_per_episode had an
over-indented line (IndentationError, breaking pre-commit + test collection)
and a zip() without strict=. Fix both; behaviour unchanged.

---------

Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-06-16 12:15:48 +02:00
Caroline Pascal 38327fdc84 fix(images/videos): fixing aggregate_pipeline_dataset_features to avoid unwanted images features deletion (#3783)
* fix(images/videos): fixing aggregate_pipeline_dataset_features to avoid unwanted images features deletion when videos are not used

* fix(docstrings): improving docstrings

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>

---------

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-06-15 17:55:52 +02:00
Steven Palma 9555efc02c chore(dependencies): update uv.lock (#3595)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-06-15 16:29:44 +02:00
Steven Palma d576c59afb refactor(robots): homogenize bi-manual setups implementations (#3772)
* chore(robots): homogenize bi setups

* feat(robots): split openarm mini into single and bi

* refactor(robots): mixin for bi classes

* docs: update docs
2026-06-15 16:28:54 +02:00
65 changed files with 2150 additions and 1411 deletions
+1
@@ -136,6 +136,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. - **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. - **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot. - **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
- **[LeLab](https://github.com/huggingface/leLab):** A web interface for LeRobot — teleoperate, calibrate, record datasets, replay, and train your SO arm from the browser, no CLI required.
## Citation ## Citation
+8 -8
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
**Compatible teleoperators:** **Compatible teleoperators:**
- `openarm_mini` - OpenArm Mini - `bi_openarm_mini` - Bimanual OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm - `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT] > [!IMPORTANT]
> The provided commands default to `bi_openarm_follower` + `openarm_mini`. > The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags. > `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
--- ---
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \ --robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \ --robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \ --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \ --teleop.type=bi_openarm_mini \
--teleop.port_left=/dev/ttyACM0 \ --teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \ --teleop.right_arm_config.port=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_dataset \ --dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \ --dataset.single_task="Fold the T-shirt properly" \
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \ --robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \ --robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \ --robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \ --teleop.type=bi_openarm_mini \
--teleop.port_left=/dev/ttyACM0 \ --teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \ --teleop.right_arm_config.port=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \ --dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \ --dataset.single_task="Fold the T-shirt properly" \
+1 -1
@@ -117,7 +117,7 @@ lerobot-rollout \
--strategy.num_episodes=20 \ --strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \ --policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \ --robot.type=bi_openarm_follower \
--teleop.type=openarm_mini \ --teleop.type=bi_openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \ --dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt" --dataset.single_task="Fold the T-shirt"
``` ```
+55
@@ -113,6 +113,61 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
--policy=act --policy=act
``` ```
## Training Large Models with FSDP
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
For example, to launch LeRobot training with FSDP across 4 GPUs on a single machine:
```bash
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
--dataset.repo_id=${HF_USER}/my_dataset \
--policy.type=<your_policy> \
--output_dir=outputs/train/my_policy_fsdp
```
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_machines: 1
num_processes: 4
fsdp_config:
fsdp_version: 1
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
fsdp_state_dict_type: FULL_STATE_DICT
```
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
optimizer before `accelerator.prepare()`.
### FSDP checkpoints
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two things to look out for:
- **Checkpoints store fp32 weights.** Under mixed precision (`bf16`/`fp16`) FSDP keeps an fp32 master
copy, and the checkpoint saves it (~2× the bf16 size on disk) so training can resume consistently
with the fp32 optimizer state; `from_pretrained` casts back to the policy dtype on load. FSDP-specific
caveat: an fp32 checkpoint is materialized in full precision on the target device _before_ casting,
so loading it for inference on a tight GPU can OOM even when the bf16 model would fit — load on CPU
first, or cast `model.safetensors` to the deployment dtype offline.
- The sharded optimizer state is gathered into a full (world-size-independent) state dict and saved
alongside the model in the same `optimizer_state.safetensors` / `optimizer_param_groups.json`
format as single-GPU training, so **resume-from-checkpoint is supported** with `--resume=true`.
Resume reshards both the model and the optimizer state to the _current_ FSDP topology, so you can
resume an FSDP checkpoint on a different number of GPUs. Note that the data sampler is only
sample-exact when the world size and batch size match the original run (a warning is logged
otherwise); the optimizer/model state itself is unaffected.
## Notes ## Notes
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration. - The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
@@ -54,6 +54,7 @@ from typing import Any
import pyarrow as pa import pyarrow as pa
import pyarrow.parquet as pq import pyarrow.parquet as pq
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
from lerobot.datasets.language import ( from lerobot.datasets.language import (
EVENT_ONLY_STYLES, EVENT_ONLY_STYLES,
LANGUAGE_EVENTS, LANGUAGE_EVENTS,
@@ -274,12 +275,11 @@ class LanguageColumnsWriter:
new_table = self._materialize_table( new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
) )
# Atomic replace: write to a sibling tmp path and rename so a crash # Re-emit one row group per episode (a bulk pq.write_table would collapse
# mid-write can't leave a half-written shard that ``pq.read_table`` # them into one). Write to a sibling tmp path and atomically rename so a
# would then fail to open. ``Path.replace`` is atomic on POSIX + # crash mid-write can't leave a half-written shard.
# Windows when source and target sit on the same filesystem.
tmp_path = path.with_suffix(path.suffix + ".tmp") tmp_path = path.with_suffix(path.suffix + ".tmp")
pq.write_table(new_table, tmp_path) write_table_one_row_group_per_episode(new_table, tmp_path)
tmp_path.replace(path) tmp_path.replace(path)
def _materialize_table( def _materialize_table(
+3 -2
@@ -442,11 +442,12 @@ class OpenCVCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues. Stops on DeviceNotConnectedError, logs other errors and continues.
""" """
if self.stop_event is None: stop_event = self.stop_event
if stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0 failure_count = 0
while not self.stop_event.is_set(): while not stop_event.is_set():
try: try:
raw_frame = self._read_from_hardware() raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame) processed_frame = self._postprocess_image(raw_frame)
@@ -471,11 +471,12 @@ class RealSenseCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues. Stops on DeviceNotConnectedError, logs other errors and continues.
""" """
if self.stop_event is None: stop_event = self.stop_event
if stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0 failure_count = 0
while not self.stop_event.is_set(): while not stop_event.is_set():
try: try:
frame = self._read_from_hardware() frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame() color_frame_raw = frame.get_color_frame()
+3 -2
@@ -246,11 +246,12 @@ class ZMQCamera(Camera):
""" """
Internal loop run by the background thread for asynchronous reading. Internal loop run by the background thread for asynchronous reading.
""" """
if self.stop_event is None: stop_event = self.stop_event
if stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.") raise RuntimeError(f"{self}: stop_event is not initialized.")
failure_count = 0 failure_count = 0
while not self.stop_event.is_set(): while not stop_event.is_set():
try: try:
frame = self._read_from_hardware() frame = self._read_from_hardware()
capture_time = time.perf_counter() capture_time = time.perf_counter()
+83 -5
@@ -21,6 +21,7 @@ from torch.optim.lr_scheduler import LRScheduler
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from lerobot.optim import ( from lerobot.optim import (
load_optimizer_state, load_optimizer_state,
load_optimizer_state_dict,
load_scheduler_state, load_scheduler_state,
save_optimizer_state, save_optimizer_state,
save_scheduler_state, save_scheduler_state,
@@ -98,6 +99,8 @@ def save_checkpoint(
postprocessor: PolicyProcessorPipeline | None = None, postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None, num_processes: int | None = None,
batch_size: int | None = None, batch_size: int | None = None,
model_state_dict: dict | None = None,
optim_state_dict: dict | None = None,
) -> None: ) -> None:
"""This function creates the following directory structure: """This function creates the following directory structure:
@@ -127,9 +130,18 @@ def save_checkpoint(
resume. Defaults to None (not recorded). resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded). resume. Defaults to None (not recorded).
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
cross-rank collective and passes it here so rank 0 can write it directly. It holds
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
optim_state_dict: Pre-gathered full (unsharded) optimizer state dict. Required under FSDP
(gathered alongside `model_state_dict` via `gather_fsdp_state_dicts`); saved in the same
safetensors format as the single-GPU path. When None, `optimizer.state_dict()` is used.
Defaults to None.
""" """
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir) policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
cfg.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir)
if cfg.peft is not None: if cfg.peft is not None:
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the # When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
@@ -140,7 +152,13 @@ def save_checkpoint(
if postprocessor is not None: if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir) postprocessor.save_pretrained(pretrained_dir)
save_training_state( save_training_state(
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size checkpoint_dir,
step,
optimizer,
scheduler,
num_processes=num_processes,
batch_size=batch_size,
optim_state_dict=optim_state_dict,
) )
@@ -151,6 +169,7 @@ def save_training_state(
scheduler: LRScheduler | None = None, scheduler: LRScheduler | None = None,
num_processes: int | None = None, num_processes: int | None = None,
batch_size: int | None = None, batch_size: int | None = None,
optim_state_dict: dict | None = None,
) -> None: ) -> None:
""" """
Saves the training step, optimizer state, scheduler state, and rng state. Saves the training step, optimizer state, scheduler state, and rng state.
@@ -164,19 +183,21 @@ def save_training_state(
Defaults to None. Defaults to None.
num_processes (int | None, optional): Distributed world size to record. Defaults to None. num_processes (int | None, optional): Distributed world size to record. Defaults to None.
batch_size (int | None, optional): Per-process batch size to record. Defaults to None. batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
optim_state_dict: Pre-gathered full optimizer state dict (for FSDP). Saved instead of
`optimizer.state_dict()` when provided. Defaults to None.
""" """
save_dir = checkpoint_dir / TRAINING_STATE_DIR save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size) save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
save_rng_state(save_dir) save_rng_state(save_dir)
if optimizer is not None: if optimizer is not None:
save_optimizer_state(optimizer, save_dir) save_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
if scheduler is not None: if scheduler is not None:
save_scheduler_state(scheduler, save_dir) save_scheduler_state(scheduler, save_dir)
def load_training_state( def load_training_state(
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None, load_optimizer: bool = True
) -> tuple[int, Optimizer, LRScheduler | None]: ) -> tuple[int, Optimizer, LRScheduler | None]:
""" """
Loads the training step, optimizer state, scheduler state, and rng state. Loads the training step, optimizer state, scheduler state, and rng state.
@@ -186,6 +207,10 @@ def load_training_state(
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir. checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
optimizer (Optimizer): The optimizer to load the state_dict to. optimizer (Optimizer): The optimizer to load the state_dict to.
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None). scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
load_optimizer (bool, optional): Whether to load the optimizer state from disk. Defaults to
True. Set to False under FSDP, where the sharded optimizer state must be loaded after
`accelerator.prepare()` via `load_fsdp_optimizer_state` (the optimizer is returned
untouched here).
Raises: Raises:
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
@@ -200,8 +225,61 @@ def load_training_state(
load_rng_state(training_state_dir) load_rng_state(training_state_dir)
step = load_training_step(training_state_dir) step = load_training_step(training_state_dir)
optimizer = load_optimizer_state(optimizer, training_state_dir) if load_optimizer:
optimizer = load_optimizer_state(optimizer, training_state_dir)
if scheduler is not None: if scheduler is not None:
scheduler = load_scheduler_state(scheduler, training_state_dir) scheduler = load_scheduler_state(scheduler, training_state_dir)
return step, optimizer, scheduler return step, optimizer, scheduler
def gather_fsdp_state_dicts(model, optimizer) -> tuple[dict, dict]:
"""Gather the full (unsharded) model and optimizer state dicts under FSDP.
`model.state_dict()` and `FSDP.optim_state_dict(...)` are cross-rank collectives, so this must be
called on *every* rank with the prepared (FSDP-wrapped) `model` and `optimizer`. With
`rank0_only=True` and `offload_to_cpu=True`, every rank runs the all-gather but only rank 0
materializes the full dicts (the others get empty dicts) and they are kept on CPU to bound GPU
memory. The returned optimizer state dict is keyed by parameter FQNs and is world-size
independent; `load_fsdp_optimizer_state` reshards it on resume.
Returns:
(model_state_dict, optim_state_dict): full dicts on rank 0, empty dicts on other ranks.
"""
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP, # noqa F401
StateDictType,
)
state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
model_state_dict = model.state_dict()
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
return model_state_dict, optim_state_dict
def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None:
"""Load the FSDP optimizer state (saved as safetensors) and reshard it into the optimizer.
This is a cross-rank collective and must be called on every rank *after* `accelerator.prepare()`
with the prepared (FSDP-wrapped) `model` and `optimizer`. The saved state is the full,
world-size-independent optimizer state (keyed by parameter FQNs); `FSDP.optim_state_dict_to_load`
reshards it to the current FSDP topology, so resume on a different number of GPUs works.
"""
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP, # noqa F401
StateDictType,
)
# Every rank reads the same full state from the (shared) checkpoint dir, so rank0_only=False.
full_osd = load_optimizer_state_dict(checkpoint_dir / TRAINING_STATE_DIR)
state_cfg = FullStateDictConfig(rank0_only=False)
optim_cfg = FullOptimStateDictConfig(rank0_only=False)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
optimizer.load_state_dict(sharded_osd)
+11 -9
@@ -180,24 +180,26 @@ class WandBLogger:
self._wandb_custom_step_key.add(new_custom_key) self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True) self._wandb.define_metric(new_custom_key, hidden=True)
batch_data = {}
for k, v in d.items(): for k, v in d.items():
# Skip the custom step key here, it's added to the batch below.
if custom_step_key is not None and k == custom_step_key:
continue
if not isinstance(v, (int | float | str)): if not isinstance(v, (int | float | str)):
logging.warning( logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.' f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
) )
continue continue
# Do not log the custom step key itself. batch_data[f"{mode}/{k}"] = v
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if batch_data:
if custom_step_key is not None: if custom_step_key is not None:
value_custom_step = d[custom_step_key] batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step} self._wandb.log(batch_data)
self._wandb.log(data) else:
continue self._wandb.log(data=batch_data, step=step)
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"): def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}: if mode not in {"train", "eval"}:
+2
@@ -79,6 +79,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch. # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: Path | None = None pretrained_path: Path | None = None
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained model version.
pretrained_revision: str | None = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device): if not self.device or not is_torch_device_available(self.device):
+2
@@ -56,6 +56,8 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
device: str | None = None device: str | None = None
pretrained_path: str | None = None pretrained_path: str | None = None
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained reward model version.
pretrained_revision: str | None = None
push_to_hub: bool = False push_to_hub: bool = False
repo_id: str | None = None repo_id: str | None = None
+9
@@ -32,6 +32,7 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
from .io_utils import ( from .io_utils import (
get_file_size_in_mb, get_file_size_in_mb,
get_parquet_file_size_in_mb, get_parquet_file_size_in_mb,
to_parquet_one_row_group_per_episode,
to_parquet_with_hf_images, to_parquet_with_hf_images,
write_info, write_info,
write_stats, write_stats,
@@ -551,6 +552,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
aggr_root=dst_meta.root, aggr_root=dst_meta.root,
hf_features=hf_features, hf_features=hf_features,
concatenate=concatenate_data, concatenate=concatenate_data,
one_row_group_per_episode=True,
) )
# Record the mapping from source to actual destination # Record the mapping from source to actual destination
@@ -628,6 +630,7 @@ def append_or_create_parquet_file(
aggr_root: Path = None, aggr_root: Path = None,
hf_features: datasets.Features | None = None, hf_features: datasets.Features | None = None,
concatenate: bool = True, concatenate: bool = True,
one_row_group_per_episode: bool = False,
) -> tuple[dict[str, int], tuple[int, int]]: ) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints. """Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -645,6 +648,8 @@ def append_or_create_parquet_file(
aggr_root: Root path for the aggregated dataset. aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing. hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one. concatenate: When False, always rotate to a new file instead of appending to the current one.
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
the episodes-metadata parquet (already one row per episode).
Returns: Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -657,6 +662,8 @@ def append_or_create_parquet_file(
dst_path.parent.mkdir(parents=True, exist_ok=True) dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images: if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features) to_parquet_with_hf_images(df, dst_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(df, dst_path)
else: else:
df.to_parquet(dst_path) df.to_parquet(dst_path)
return idx, (dst_chunk, dst_file) return idx, (dst_chunk, dst_file)
@@ -683,6 +690,8 @@ def append_or_create_parquet_file(
if contains_images: if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features) to_parquet_with_hf_images(final_df, target_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(final_df, target_path)
else: else:
final_df.to_parquet(target_path) final_df.to_parquet(target_path)
+2 -1
@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@@ -709,7 +710,7 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False) obj.root.mkdir(parents=True, exist_ok=False)
features = {**features, **DEFAULT_FEATURES} features = {**deepcopy(features), **DEFAULT_FEATURES}
_validate_feature_names(features) _validate_feature_names(features)
obj.tasks = None obj.tasks = None
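Review note: `{**features, **DEFAULT_FEATURES}` builds a new top-level dict but still shares every nested dict with the caller, which is presumably what the `deepcopy` guards against. A small sketch of the difference, using a stand-in `DEFAULT_FEATURES` and hypothetical helper names:

```python
from copy import deepcopy

# Stand-in for illustration only; the real DEFAULT_FEATURES is defined elsewhere.
DEFAULT_FEATURES = {"index": {"dtype": "int64", "shape": (1,)}}


def merge_shallow(features):
    # Old behavior: nested dicts are shared with the caller's `features`.
    return {**features, **DEFAULT_FEATURES}


def merge_deep(features):
    # New behavior: nested dicts are copied before merging.
    return {**deepcopy(features), **DEFAULT_FEATURES}


user = {"cam": {"dtype": "video", "info": {}}}
shallow = merge_shallow(user)
shallow["cam"]["info"]["fps"] = 30  # also mutates the caller's nested dict
leaked = "fps" in user["cam"]["info"]

user2 = {"cam": {"dtype": "video", "info": {}}}
deep = merge_deep(user2)
deep["cam"]["info"]["fps"] = 30  # the caller's dict stays untouched
isolated = "fps" not in user2["cam"]["info"]
```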
+12
@@ -74,6 +74,8 @@ class DatasetReader:
self.episodes = episodes self.episodes = episodes
self._tolerance_s = tolerance_s self._tolerance_s = tolerance_s
self._video_backend = video_backend self._video_backend = video_backend
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms self._image_transforms = image_transforms
self._return_uint8 = return_uint8 self._return_uint8 = return_uint8
@@ -86,6 +88,16 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms
def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations."""
self._image_transforms = None
def try_load(self) -> bool: def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient.""" """Attempt to load from local cache. Returns True if data is sufficient."""
try: try:
+4 -1
@@ -27,6 +27,7 @@ import logging
import shutil import shutil
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from copy import deepcopy
from pathlib import Path from pathlib import Path
import datasets import datasets
@@ -1101,7 +1102,9 @@ def _copy_episodes_metadata_and_stats(
if dst_meta.video_keys and src_dataset.meta.video_keys: if dst_meta.video_keys and src_dataset.meta.video_keys:
for key in dst_meta.video_keys: for key in dst_meta.video_keys:
if key in src_dataset.meta.features: if key in src_dataset.meta.features:
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {}) dst_meta.info.features[key]["info"] = deepcopy(
src_dataset.meta.info.features[key].get("info", {})
)
write_info(dst_meta.info, dst_meta.root) write_info(dst_meta.info, dst_meta.root)
+39 -10
@@ -20,6 +20,7 @@ import datasets
import numpy as np import numpy as np
import pandas import pandas
import pandas as pd import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pa_ds import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq import pyarrow.parquet as pq
import torch import torch
@@ -153,7 +154,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
Returns: Returns:
dict: The statistics dictionary with values cast to numpy arrays. dict: The statistics dictionary with values cast to numpy arrays.
""" """
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats) return unflatten_dict(stats)
@@ -270,21 +271,49 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
return items_dict return items_dict
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
"""Write ``table`` with one parquet row group per episode (in episode order).
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
mirroring the recording writer. ``table`` must carry a contiguous
``episode_index`` column.
"""
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
try:
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
finally:
writer.close()
def to_parquet_with_hf_images( def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None: ) -> None:
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. """Write a DataFrame with HF-encoded images to parquet, one row group per episode.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
Args: Images are embedded into the arrow table first (``ParquetWriter.write_table``
df: DataFrame to write to parquet. does not embed external image files like ``Dataset.to_parquet`` does).
path: Path to write the parquet file. ``features`` types image columns as ``Image()`` in the parquet schema.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
""" """
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds.to_parquet(path) ds = embed_images(ds)
table = ds.with_format("arrow")[:]
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
# No episode boundaries to align row groups to — keep a single write.
pq.write_table(table, str(path))
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
table = pa.Table.from_pandas(df, preserve_index=False)
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
pq.write_table(table, str(path))
def item_to_torch(item: dict) -> dict: def item_to_torch(item: dict) -> dict:
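Review note: `write_table_one_row_group_per_episode` derives slice boundaries from the positions where `episode_index` changes. The same boundary logic in dependency-free Python, as a sketch only (the helper name `episode_slices` is hypothetical, and the patch itself uses `np.diff` / `np.nonzero`):

```python
def episode_slices(episode_index):
    """Return (start, length) for each run of equal consecutive values.

    Assumes, like the patch, that rows of one episode are contiguous.
    """
    slices = []
    start = 0
    n = len(episode_index)
    for i in range(1, n + 1):
        # Close the current run at the end of the data or when the value changes.
        if i == n or episode_index[i] != episode_index[i - 1]:
            slices.append((start, i - start))
            start = i
    return slices
```

Each `(start, length)` pair corresponds to one `table.slice(start, length)` written as its own row group. If an episode's rows were not contiguous, it would be split across several row groups, which is why the docstring requires a contiguous `episode_index` column.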
+5 -7
@@ -201,8 +201,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
self._requested_root = Path(root) if root else None self._requested_root = Path(root) if root else None
self.reader = None
self.set_image_transforms(image_transforms)
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
@@ -249,6 +247,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_transforms=image_transforms, image_transforms=image_transforms,
return_uint8=self._return_uint8, return_uint8=self._return_uint8,
) )
self.image_transforms = image_transforms
# Load actual data # Load actual data
if force_cache_sync or not self.reader.try_load(): if force_cache_sync or not self.reader.try_load():
@@ -505,15 +504,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
def set_image_transforms(self, image_transforms: Callable | None) -> None: def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations.""" """Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms): self._ensure_reader().set_image_transforms(image_transforms)
raise TypeError("image_transforms must be callable or None.")
self.image_transforms = image_transforms self.image_transforms = image_transforms
if self.reader is not None:
self.reader._image_transforms = image_transforms
def clear_image_transforms(self) -> None: def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations.""" """Remove the transform applied to visual observations."""
self.set_image_transforms(None) if self.reader is not None:
self.reader.set_image_transforms(None)
self.image_transforms = None
# ── Hub methods (stay on facade) ────────────────────────────────── # ── Hub methods (stay on facade) ──────────────────────────────────
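Review note: after this change the callable check lives on the reader, and the dataset facade only mirrors the value once the reader has accepted it. A reduced sketch of that delegation, with hypothetical `Reader` / `Facade` stand-ins for `DatasetReader` / `LeRobotDataset` (the real facade goes through `_ensure_reader()`, which is not modeled here):

```python
class Reader:
    """Stand-in for DatasetReader: owns validation of the transform."""

    def __init__(self, image_transforms=None):
        self._image_transforms = None
        self.set_image_transforms(image_transforms)

    def set_image_transforms(self, image_transforms):
        if image_transforms is not None and not callable(image_transforms):
            raise TypeError("image_transforms must be callable or None.")
        self._image_transforms = image_transforms


class Facade:
    """Stand-in for the dataset facade: delegates, then mirrors the value."""

    def __init__(self, image_transforms=None):
        self.reader = Reader(image_transforms)
        self.image_transforms = image_transforms

    def set_image_transforms(self, image_transforms):
        # Delegating first means an invalid value raises before the
        # facade attribute is touched.
        self.reader.set_image_transforms(image_transforms)
        self.image_transforms = image_transforms

    def clear_image_transforms(self):
        if self.reader is not None:
            self.reader.set_image_transforms(None)
        self.image_transforms = None
```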
+5 -3
@@ -70,19 +70,21 @@ def aggregate_pipeline_dataset_features(
initial_features: dict[PipelineFeatureType, dict[str, Any]], initial_features: dict[PipelineFeatureType, dict[str, Any]],
*, *,
use_videos: bool = True, use_videos: bool = True,
exclude_images: bool = False,
patterns: Sequence[str] | None = None, patterns: Sequence[str] | None = None,
) -> dict[str, dict]: ) -> dict[str, dict]:
""" """
Aggregates and filters pipeline features to create a dataset-ready features dictionary. Aggregates and filters pipeline features to create a dataset-ready features dictionary.
This function transforms initial features using the pipeline, categorizes them as action or observations This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `use_videos` and `patterns`, and finally (image or state), filters them based on `exclude_images` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset. formats them for use with a Hugging Face LeRobot Dataset.
Args: Args:
pipeline: The DataProcessorPipeline to apply. pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations. initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: If False, image features are excluded. use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
exclude_images: If True, image features are dropped entirely from the output.
patterns: A sequence of regex patterns to filter action and state features. patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter. Image features are not affected by this filter.
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
) )
# 2. Apply filtering rules. # 2. Apply filtering rules.
if is_image and not use_videos: if is_image and exclude_images:
continue continue
if not is_image and not should_keep(key, compiled_patterns): if not is_image and not should_keep(key, compiled_patterns):
continue continue
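Review note: this hunk separates two concerns that `use_videos` previously covered together: `use_videos` now only picks the storage dtype for images, while `exclude_images` decides whether they appear at all. A simplified sketch of the resulting rules, assuming a hypothetical `select_features` helper, a `{name: is_image}` input, a placeholder `"float32"` dtype for non-image features, and `re.search` matching for `patterns` (the real `should_keep` may differ):

```python
import re


def select_features(features, *, use_videos=True, exclude_images=False, patterns=None):
    """Map {name: is_image} to {name: dtype label} under the new filtering rules."""
    compiled = [re.compile(p) for p in patterns] if patterns else None
    out = {}
    for key, is_image in features.items():
        # Images are dropped only when explicitly excluded.
        if is_image and exclude_images:
            continue
        # Patterns filter action/state features; images are not affected.
        if not is_image and compiled is not None and not any(p.search(key) for p in compiled):
            continue
        if is_image:
            out[key] = "video" if use_videos else "image"
        else:
            out[key] = "float32"  # placeholder dtype for illustration
    return out
```

With `use_videos=False` and the default `exclude_images=False`, images are kept and labeled `"image"`, whereas before this change they would have been dropped.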
+20
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
if "camera_obs" in observations: if "camera_obs" in observations:
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"] return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
# Pass through any remaining ndarray/tensor keys not already handled above,
# so env plugins can expose extra observation keys via get_env_processors().
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
for key, value in observations.items():
if key in _handled:
continue
target = f"{OBS_STR}.{key}"
if target in return_observations:
continue
if isinstance(value, np.ndarray):
val = torch.from_numpy(value).float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
elif isinstance(value, Tensor):
val = value.float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
return return_observations return return_observations
+2
@@ -20,6 +20,7 @@ from .optimizers import (
SGDConfig as SGDConfig, SGDConfig as SGDConfig,
XVLAAdamWConfig as XVLAAdamWConfig, XVLAAdamWConfig as XVLAAdamWConfig,
load_optimizer_state, load_optimizer_state,
load_optimizer_state_dict,
save_optimizer_state, save_optimizer_state,
) )
from .schedulers import ( from .schedulers import (
@@ -50,6 +51,7 @@ __all__ = [
"VQBeTSchedulerConfig", "VQBeTSchedulerConfig",
# State management # State management
"load_optimizer_state", "load_optimizer_state",
"load_optimizer_state_dict",
"load_scheduler_state", "load_scheduler_state",
"save_optimizer_state", "save_optimizer_state",
"save_scheduler_state", "save_scheduler_state",
+30 -5
@@ -27,7 +27,7 @@ from lerobot.utils.constants import (
OPTIMIZER_PARAM_GROUPS, OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE, OPTIMIZER_STATE,
) )
from lerobot.utils.io_utils import deserialize_json_into_object, write_json from lerobot.utils.io_utils import deserialize_json_into_object, load_json, write_json
from lerobot.utils.utils import flatten_dict, unflatten_dict from lerobot.utils.utils import flatten_dict, unflatten_dict
# Type alias for parameters accepted by optimizer build() methods. # Type alias for parameters accepted by optimizer build() methods.
@@ -281,28 +281,37 @@ class MultiAdamConfig(OptimizerConfig):
def save_optimizer_state( def save_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer],
save_dir: Path,
optim_state_dict: dict | None = None,
) -> None: ) -> None:
"""Save optimizer state to disk. """Save optimizer state to disk.
Args: Args:
optimizer: Either a single optimizer or a dictionary of optimizers. optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to save the optimizer state. save_dir: Directory to save the optimizer state.
optim_state_dict: Pre-gathered optimizer state dict (for FSDP, where the sharded state must
be gathered across ranks first). If provided, it is saved directly instead of calling
``optimizer.state_dict()``. Only supported for a single optimizer. Defaults to None.
""" """
if isinstance(optimizer, dict): if isinstance(optimizer, dict):
# Handle dictionary of optimizers # Handle dictionary of optimizers
if optim_state_dict is not None:
raise ValueError("optim_state_dict is not supported for a dict of optimizers")
for name, opt in optimizer.items(): for name, opt in optimizer.items():
optimizer_dir = save_dir / name optimizer_dir = save_dir / name
optimizer_dir.mkdir(exist_ok=True, parents=True) optimizer_dir.mkdir(exist_ok=True, parents=True)
_save_single_optimizer_state(opt, optimizer_dir) _save_single_optimizer_state(opt, optimizer_dir)
else: else:
# Handle single optimizer # Handle single optimizer
_save_single_optimizer_state(optimizer, save_dir) _save_single_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None: def _save_single_optimizer_state(
optimizer: torch.optim.Optimizer, save_dir: Path, optim_state_dict: dict | None = None
) -> None:
"""Save a single optimizer's state to disk.""" """Save a single optimizer's state to disk."""
state = optimizer.state_dict() state = dict(optim_state_dict) if optim_state_dict is not None else optimizer.state_dict()
param_groups = state.pop("param_groups") param_groups = state.pop("param_groups")
flat_state = flatten_dict(state) flat_state = flatten_dict(state)
save_file(flat_state, save_dir / OPTIMIZER_STATE) save_file(flat_state, save_dir / OPTIMIZER_STATE)
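Review note: wrapping the supplied dict in `dict(...)` matters because the next line pops `"param_groups"`. Without the shallow copy, the caller's pre-gathered state dict would lose that key. A tiny sketch, with a hypothetical helper name `split_state` and placeholder values:

```python
def split_state(optim_state_dict):
    """Separate param_groups from the rest without mutating the caller's dict."""
    state = dict(optim_state_dict)  # shallow copy: only top-level keys are duplicated
    param_groups = state.pop("param_groups")
    return state, param_groups


gathered = {"state": {"layer.weight": {"step": 3}}, "param_groups": [{"lr": 1e-4}]}
state, param_groups = split_state(gathered)
```

The copy is shallow, so the nested `"state"` entries are still shared with the caller; that is sufficient here since only the top-level key is removed.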
@@ -356,3 +365,19 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
optimizer.load_state_dict(loaded_state_dict) optimizer.load_state_dict(loaded_state_dict)
return optimizer return optimizer
def load_optimizer_state_dict(save_dir: Path) -> dict:
"""Read a saved optimizer state dict (safetensors + json) back into a plain dict.
Unlike `load_optimizer_state`, this does not load into an optimizer and preserves the original
``state`` keys verbatim (e.g. FSDP parameter FQNs, which are not integer-castable). It is used by
the FSDP resume path, where the full state must be resharded via `FSDP.optim_state_dict_to_load`
before being loaded into the (sharded) optimizer.
"""
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
return {
"state": state.get("state", {}),
"param_groups": load_json(save_dir / OPTIMIZER_PARAM_GROUPS),
}
+1 -1
@@ -148,7 +148,7 @@ class ACTPolicy(PreTrainedPolicy):
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1) l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
loss_dict = {"l1_loss": l1_loss.item()} loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae: if self.config.use_vae and log_sigma_x2_hat is not None:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total # each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch. # KL-divergence per batch element, then take the mean over the batch.
@@ -101,11 +101,23 @@ class DiffusionPolicy(PreTrainedPolicy):
@torch.no_grad() @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Predict a chunk of actions given environment observations.""" """Predict a chunk of actions given environment observations.
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch, noise=noise)
Supports two modes:
- Online (queues populated via select_action): stacks observations from internal queues.
- Offline (empty queues, e.g. dataloader batch): uses the batch directly.
"""
queues_populated = any(len(q) > 0 for q in self._queues.values())
if queues_populated:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
else:
batch = dict(batch)
if self.config.image_features:
for key in self.config.image_features:
if batch[key].ndim == 4:
batch[key] = batch[key].unsqueeze(1)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
actions = self.diffusion.generate_actions(batch, noise=noise)
return actions return actions
@torch.no_grad() @torch.no_grad()
+4
@@ -252,6 +252,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
def make_pre_post_processors( def make_pre_post_processors(
policy_cfg: PreTrainedConfig, policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None, pretrained_path: str | None = None,
pretrained_revision: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs], **kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[ ) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -309,6 +310,7 @@ def make_pre_post_processors(
overrides=kwargs.get("preprocessor_overrides", {}), overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=batch_to_transition, to_transition=batch_to_transition,
to_output=transition_to_batch, to_output=transition_to_batch,
revision=pretrained_revision,
) )
postprocessor = PolicyProcessorPipeline.from_pretrained( postprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path, pretrained_model_name_or_path=pretrained_path,
@@ -318,6 +320,7 @@ def make_pre_post_processors(
overrides=kwargs.get("postprocessor_overrides", {}), overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=policy_action_to_transition, to_transition=policy_action_to_transition,
to_output=transition_to_policy_action, to_output=transition_to_policy_action,
revision=pretrained_revision,
) )
_reconnect_relative_absolute_steps(preprocessor, postprocessor) _reconnect_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor return preprocessor, postprocessor
@@ -557,6 +560,7 @@ def make_policy(
# Load a pretrained policy and override the config if needed (for example, if there are inference-time # Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary). # hyperparameters that we want to vary).
kwargs["pretrained_name_or_path"] = cfg.pretrained_path kwargs["pretrained_name_or_path"] = cfg.pretrained_path
kwargs["revision"] = cfg.pretrained_revision
policy = policy_cls.from_pretrained(**kwargs) policy = policy_cls.from_pretrained(**kwargs)
elif cfg.pretrained_path and cfg.use_peft: elif cfg.pretrained_path and cfg.use_peft:
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo # Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
+39 -4
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
import packaging import packaging
import safetensors import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
@@ -129,10 +129,43 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if not getattr(cls, "name", None): if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'") raise TypeError(f"Class {cls.__name__} must define 'name'")
def _save_pretrained(self, save_directory: Path) -> None: def save_pretrained(
self,
save_directory: str | Path,
*,
state_dict: dict[str, Tensor] | None = None,
repo_id: str | None = None,
push_to_hub: bool = False,
card_kwargs: dict | None = None,
**push_to_hub_kwargs,
) -> str | None:
"""Save the policy to a directory (and optionally push to the Hub).
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
return sharded tensors, so the caller gathers the full state dict via a cross-rank
collective and passes it here for `_save_pretrained` to write directly.
"""
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
self._save_pretrained(save_directory, state_dict=state_dict)
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return None
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
self.config._save_pretrained(save_directory) self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)) if state_dict is None:
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
return
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
@classmethod @classmethod
def from_pretrained( def from_pretrained(
@@ -270,6 +303,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self, self,
cfg: TrainPipelineConfig, cfg: TrainPipelineConfig,
peft_model=None, peft_model=None,
state_dict: dict[str, Tensor] | None = None,
): ):
api = HfApi() api = HfApi()
repo_id = api.create_repo( repo_id = api.create_repo(
@@ -287,7 +321,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
peft_model.save_pretrained(saved_path) peft_model.save_pretrained(saved_path)
self.config.save_pretrained(saved_path) self.config.save_pretrained(saved_path)
else: else:
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors # Calls _save_pretrained and stores model tensors
self.save_pretrained(saved_path, state_dict=state_dict)
card = self.generate_model_card( card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
+1
@@ -124,6 +124,7 @@ def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel
if cfg.pretrained_path: if cfg.pretrained_path:
kwargs["pretrained_name_or_path"] = cfg.pretrained_path kwargs["pretrained_name_or_path"] = cfg.pretrained_path
kwargs["revision"] = cfg.pretrained_revision
reward_model = reward_cls.from_pretrained(**kwargs) reward_model = reward_cls.from_pretrained(**kwargs)
else: else:
reward_model = reward_cls(**kwargs) reward_model = reward_cls(**kwargs)
@@ -18,7 +18,8 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot from ..robot import Robot
@@ -27,7 +28,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiOpenArmFollower(Robot): class BiOpenArmFollower(BimanualMixin, Robot):
""" """
Bimanual OpenArm Follower Arms Bimanual OpenArm Follower Arms
""" """
@@ -39,15 +40,17 @@ class BiOpenArmFollower(Robot):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Top-level cameras are distributed evenly: each arm's OpenArmFollower # Top-level cameras are opened by `left_arm` for convenience, but their
# will only open the cameras assigned to it. Per-arm cameras are used # keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
# as fallback when top-level cameras are empty. self._top_level_cam_keys = set(config.cameras)
if config.cameras: _collisions = self._top_level_cam_keys & set(
left_cameras = config.cameras config.left_arm_config.cameras
right_cameras = {} ) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
else: if _collisions:
left_cameras = config.left_arm_config.cameras raise ValueError(
right_cameras = config.right_arm_config.cameras f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = OpenArmFollowerConfig( left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None, id=f"{config.id}_left" if config.id else None,
@@ -56,7 +59,7 @@ class BiOpenArmFollower(Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque, use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target, max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_cameras, cameras=left_arm_cameras,
side=config.left_arm_config.side, side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface, can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd, use_can_fd=config.left_arm_config.use_can_fd,
@@ -75,7 +78,7 @@ class BiOpenArmFollower(Robot):
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque, use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target, max_relative_target=config.right_arm_config.max_relative_target,
cameras=right_cameras, cameras=config.right_arm_config.cameras,
side=config.right_arm_config.side, side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface, can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd, use_can_fd=config.right_arm_config.use_can_fd,
@@ -95,22 +98,19 @@ class BiOpenArmFollower(Robot):
@property @property
def _motors_ft(self) -> dict[str, type]: def _motors_ft(self) -> dict[str, type]:
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return { return {
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()}, **{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()}, **{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
} }
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base", out: dict[str, tuple] = {}
# "right_wrist"), so we merge them directly — unlike motors which need the for k, v in self.left_arm._cameras_ft.items():
# left_/right_ prefix to disambiguate identical per-arm joint names. out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft} for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -120,27 +120,6 @@ class BiOpenArmFollower(Robot):
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return self._motors_ft return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors." "Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -148,21 +127,15 @@ class BiOpenArmFollower(Robot):
@check_if_not_connected @check_if_not_connected
def get_observation(self) -> RobotObservation: def get_observation(self) -> RobotObservation:
obs_dict = {} obs_dict: RobotObservation = {}
# Camera keys that should NOT get the arm prefix (they already have unique names) # Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
left_cam_keys = set(self.left_arm.cameras.keys()) for key, value in self.left_arm.get_observation().items():
right_cam_keys = set(self.right_arm.cameras.keys()) obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Right first, then left — matches the teleoperator (OpenArmMini) ordering # Add "right_" prefix
# and the dataset feature names recorded during data collection. for key, value in self.right_arm.get_observation().items():
right_obs = self.right_arm.get_observation() obs_dict[f"right_{key}"] = value
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
return obs_dict return obs_dict
@@ -189,9 +162,4 @@ class BiOpenArmFollower(Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()} prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_right, **prefixed_sent_action_left} return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
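Review note: the lifecycle methods removed from this class (`is_connected`, `connect`, `is_calibrated`, `calibrate`, `configure`, `disconnect`) now presumably come from `BimanualMixin`, whose source is not shown in this part of the diff. The sketch below is a hypothetical reconstruction based only on the removed methods. `BimanualMixinSketch`, `DeviceBase`, and `FakeArm` are invented names, and the real mixin may differ (for example, it may keep the connection-check decorators).

```python
from abc import ABC, abstractmethod


class DeviceBase(ABC):
    """Hypothetical stand-in for an abstract Robot/Teleoperator base class."""

    @property
    @abstractmethod
    def is_connected(self) -> bool: ...

    @abstractmethod
    def connect(self, calibrate: bool = True) -> None: ...

    @abstractmethod
    def disconnect(self) -> None: ...


class BimanualMixinSketch:
    """Assumed shape of the shared logic: forward every call to both arms.

    Expects the host class to set `self.left_arm` and `self.right_arm`.
    """

    @property
    def is_connected(self) -> bool:
        return self.left_arm.is_connected and self.right_arm.is_connected

    def connect(self, calibrate: bool = True) -> None:
        self.left_arm.connect(calibrate)
        self.right_arm.connect(calibrate)

    @property
    def is_calibrated(self) -> bool:
        return self.left_arm.is_calibrated and self.right_arm.is_calibrated

    def calibrate(self) -> None:
        self.left_arm.calibrate()
        self.right_arm.calibrate()

    def configure(self) -> None:
        self.left_arm.configure()
        self.right_arm.configure()

    def disconnect(self) -> None:
        self.left_arm.disconnect()
        self.right_arm.disconnect()


class FakeArm:
    """Minimal fake single arm, used only to exercise the sketch."""

    def __init__(self) -> None:
        self.is_connected = False
        self.is_calibrated = False

    def connect(self, calibrate: bool = True) -> None:
        self.is_connected = True
        if calibrate:
            self.calibrate()

    def calibrate(self) -> None:
        self.is_calibrated = True

    def configure(self) -> None:
        pass

    def disconnect(self) -> None:
        self.is_connected = False


# The mixin is listed first so its concrete methods satisfy the abstract base.
class FakeBimanualRobot(BimanualMixinSketch, DeviceBase):
    def __init__(self) -> None:
        self.left_arm = FakeArm()
        self.right_arm = FakeArm()


robot = FakeBimanualRobot()
robot.connect()
print(robot.is_connected, robot.is_calibrated)
robot.disconnect()
print(robot.is_connected)
```

The base order `(BimanualMixin, Robot)` used throughout the diff matters for the same reason as in the sketch: with the mixin first, attribute lookup finds its concrete methods before the abstract ones on the base class.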
@@ -32,5 +32,7 @@ class BiOpenArmFollowerConfig(RobotConfig):
left_arm_config: OpenArmFollowerConfigBase left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras shared across both arms. # Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict) cameras: dict[str, CameraConfig] = field(default_factory=dict)
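Review note: the key-naming rule described in the comment above can be checked in isolation. This is a minimal sketch, not the library's code: `merge_camera_keys` is a hypothetical helper, and plain strings stand in for `CameraConfig` objects.

```python
def merge_camera_keys(top_level: dict, left: dict, right: dict) -> dict:
    """Return the observation keys for a bimanual robot's cameras.

    Hypothetical helper for illustration: each input maps a camera name to
    any value (a stand-in for a camera config).
    """
    top_keys = set(top_level)
    # `&` binds tighter than `|`, so this is (top & left) | (top & right).
    collisions = top_keys & set(left) | top_keys & set(right)
    if collisions:
        raise ValueError(
            f"Top-level camera names collide with per-arm camera names: {sorted(collisions)}"
        )

    # Top-level cameras are handed to the left arm, as in the diff.
    left_arm_cameras = {**left, **top_level}

    out = {}
    for name, cam in left_arm_cameras.items():
        # Top-level names stay as-is; per-arm names get the side prefix.
        out[name if name in top_keys else f"left_{name}"] = cam
    for name, cam in right.items():
        out[f"right_{name}"] = cam
    return out


keys = merge_camera_keys(
    top_level={"base": "cam0"},
    left={"wrist": "cam1"},
    right={"wrist": "cam2"},
)
print(sorted(keys))  # ['base', 'left_wrist', 'right_wrist']
```

Without the collision check, a top-level camera sharing a name with a left-arm camera would silently overwrite it in the merged `left_arm_cameras` dict, which is why the constructor raises early.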
@@ -18,7 +18,8 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
from ..robot import Robot from ..robot import Robot
@@ -27,7 +28,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiRebotB601Follower(Robot): class BiRebotB601Follower(BimanualMixin, Robot):
"""Bimanual Seeed Studio reBot B601-DM follower. """Bimanual Seeed Studio reBot B601-DM follower.
Composes two single-arm :class:`RebotB601Follower` instances. Observation and Composes two single-arm :class:`RebotB601Follower` instances. Observation and
@@ -41,6 +42,18 @@ class BiRebotB601Follower(Robot):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = RebotB601FollowerRobotConfig( left_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_left" if config.id else None, id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir, calibration_dir=config.calibration_dir,
@@ -49,7 +62,7 @@ class BiRebotB601Follower(Robot):
dm_serial_baud=config.left_arm_config.dm_serial_baud, dm_serial_baud=config.left_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target, max_relative_target=config.left_arm_config.max_relative_target,
cameras=config.left_arm_config.cameras, cameras=left_arm_cameras,
motor_can_ids=config.left_arm_config.motor_can_ids, motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity, pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio, gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
@@ -86,10 +99,12 @@ class BiRebotB601Follower(Robot):
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
return { out: dict[str, tuple] = {}
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()}, for k, v in self.left_arm._cameras_ft.items():
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()}, out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
} for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -99,32 +114,13 @@ class BiRebotB601Follower(Robot):
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return self._motors_ft return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected @check_if_not_connected
def get_observation(self) -> RobotObservation: def get_observation(self) -> RobotObservation:
obs_dict = {} obs_dict: RobotObservation = {}
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()}) for k, v in self.left_arm.get_observation().items():
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()}) obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm.get_observation().items():
obs_dict[f"right_{k}"] = v
return obs_dict return obs_dict
@check_if_not_connected @check_if_not_connected
@@ -143,8 +139,3 @@ class BiRebotB601Follower(Robot):
**{f"left_{k}": v for k, v in sent_action_left.items()}, **{f"left_{k}": v for k, v in sent_action_left.items()},
**{f"right_{k}": v for k, v in sent_action_right.items()}, **{f"right_{k}": v for k, v in sent_action_right.items()},
} }
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig from ..config import RobotConfig
from ..rebot_b601_follower import RebotB601FollowerConfig from ..rebot_b601_follower import RebotB601FollowerConfig
@@ -27,3 +29,8 @@ class BiRebotB601FollowerConfig(RobotConfig):
left_arm_config: RebotB601FollowerConfig left_arm_config: RebotB601FollowerConfig
right_arm_config: RebotB601FollowerConfig right_arm_config: RebotB601FollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,7 +18,8 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction, RobotObservation from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..robot import Robot from ..robot import Robot
from ..so_follower import SOFollower, SOFollowerRobotConfig from ..so_follower import SOFollower, SOFollowerRobotConfig
@@ -27,7 +28,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiSOFollower(Robot): class BiSOFollower(BimanualMixin, Robot):
""" """
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio [Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
""" """
@@ -39,6 +40,18 @@ class BiSOFollower(Robot):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = SOFollowerRobotConfig( left_arm_config = SOFollowerRobotConfig(
id=f"{config.id}_left" if config.id else None, id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir, calibration_dir=config.calibration_dir,
@@ -46,7 +59,7 @@ class BiSOFollower(Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target, max_relative_target=config.left_arm_config.max_relative_target,
use_degrees=config.left_arm_config.use_degrees, use_degrees=config.left_arm_config.use_degrees,
cameras=config.left_arm_config.cameras, cameras=left_arm_cameras,
) )
right_arm_config = SOFollowerRobotConfig( right_arm_config = SOFollowerRobotConfig(
@@ -77,13 +90,12 @@ class BiSOFollower(Robot):
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
left_arm_cameras_ft = self.left_arm._cameras_ft out: dict[str, tuple] = {}
right_arm_cameras_ft = self.right_arm._cameras_ft for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
return { for k, v in self.right_arm._cameras_ft.items():
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()}, out[f"right_{k}"] = v
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()}, return out
}
@cached_property @cached_property
def observation_features(self) -> dict[str, type | tuple]: def observation_features(self) -> dict[str, type | tuple]:
@@ -93,42 +105,21 @@ class BiSOFollower(Robot):
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return self._motors_ft return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
self.left_arm.setup_motors() self.left_arm.setup_motors()
self.right_arm.setup_motors() self.right_arm.setup_motors()
@check_if_not_connected @check_if_not_connected
def get_observation(self) -> RobotObservation: def get_observation(self) -> RobotObservation:
obs_dict = {} obs_dict: RobotObservation = {}
# Add "left_" prefix # Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
left_obs = self.left_arm.get_observation() for key, value in self.left_arm.get_observation().items():
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()}) obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Add "right_" prefix # Add "right_" prefix
right_obs = self.right_arm.get_observation() for key, value in self.right_arm.get_observation().items():
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()}) obs_dict[f"right_{key}"] = value
return obs_dict return obs_dict
@@ -151,8 +142,3 @@ class BiSOFollower(Robot):
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right} return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig from ..config import RobotConfig
from ..so_follower import SOFollowerConfig from ..so_follower import SOFollowerConfig
@@ -27,3 +29,8 @@ class BiSOFollowerConfig(RobotConfig):
left_arm_config: SOFollowerConfig left_arm_config: SOFollowerConfig
right_arm_config: SOFollowerConfig right_arm_config: SOFollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
+1
@@ -54,6 +54,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
homunculus, homunculus,
@@ -57,6 +57,7 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
gamepad, gamepad,
+1
@@ -137,6 +137,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
homunculus, homunculus,
+1
@@ -174,6 +174,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
homunculus, homunculus,
@@ -41,6 +41,7 @@ from lerobot.robots import ( # noqa: F401
) )
from lerobot.teleoperators import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
koch_leader, koch_leader,
@@ -89,6 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
TeleoperatorConfig, TeleoperatorConfig,
bi_openarm_leader, bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader, bi_rebot_102_leader,
bi_so_leader, bi_so_leader,
gamepad, gamepad,
+29 -4
@@ -34,8 +34,10 @@ from torch.optim import Optimizer
from tqdm import tqdm from tqdm import tqdm
from lerobot.common.train_utils import ( from lerobot.common.train_utils import (
gather_fsdp_state_dicts,
get_step_checkpoint_dir, get_step_checkpoint_dir,
get_step_identifier, get_step_identifier,
load_fsdp_optimizer_state,
load_training_batch_size, load_training_batch_size,
load_training_num_processes, load_training_num_processes,
load_training_state, load_training_state,
@@ -189,6 +191,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
require_package("accelerate", extra="training") require_package("accelerate", extra="training")
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
cfg.validate() cfg.validate()
@@ -197,8 +200,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation # We set find_unused_parameters=True to handle models with conditional computation
if accelerator is None: if accelerator is None:
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting. # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training). # Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
@@ -345,6 +346,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path, pretrained_path=processor_pretrained_path,
pretrained_revision=getattr(cfg.policy, "pretrained_revision", None),
**processor_kwargs, **processor_kwargs,
) )
@@ -370,7 +372,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
step = 0 # number of policy updates (forward + backward + optim) step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume: if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) # Under FSDP the optimizer state is sharded and must be loaded after `accelerator.prepare()`
# (see load_fsdp_optimizer_state below), so skip the optimizer here and load it then.
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
step, optimizer, lr_scheduler = load_training_state(
cfg.checkpoint_path, optimizer, lr_scheduler, load_optimizer=not is_fsdp
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
@@ -460,6 +467,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler policy, optimizer, dataloader, lr_scheduler
) )
# FSDP optimizer state is sharded across ranks, so it can only be loaded once the optimizer and
# model are FSDP-wrapped (i.e. after `prepare`). Collective: every rank must participate.
if cfg.resume and accelerator.distributed_type == DistributedType.FSDP:
load_fsdp_optimizer_state(policy, optimizer, cfg.checkpoint_path)
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
policy.train() policy.train()
@@ -558,6 +571,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
train_tracker.reset_averages() train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step: if cfg.save_checkpoint and is_saving_step:
# Under FSDP, gathering the full model + optimizer state dicts is a cross-rank collective,
# so all ranks must participate; rank 0 then writes the materialized dicts. For DDP /
# single-GPU the state dicts are saved the normal way inside save_checkpoint.
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
if is_fsdp:
model_state_dict, optim_state_dict = gather_fsdp_state_dicts(policy, optimizer)
else:
model_state_dict, optim_state_dict = None, None
if is_main_process: if is_main_process:
logging.info(f"Checkpoint policy after step {step}") logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
@@ -572,6 +593,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
postprocessor=postprocessor, postprocessor=postprocessor,
num_processes=accelerator.num_processes, num_processes=accelerator.num_processes,
batch_size=cfg.batch_size, batch_size=cfg.batch_size,
model_state_dict=model_state_dict,
optim_state_dict=optim_state_dict,
) )
update_last_checkpoint(checkpoint_dir) update_last_checkpoint(checkpoint_dir)
if wandb_logger: if wandb_logger:
@@ -634,6 +657,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if eval_env: if eval_env:
close_envs(eval_env) close_envs(eval_env)
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
model_state_dict = accelerator.get_state_dict(policy) if is_fsdp else None
if is_main_process: if is_main_process:
logging.info("End of training") logging.info("End of training")
@@ -643,7 +668,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if not cfg.is_reward_model_training and cfg.policy.use_peft: if not cfg.is_reward_model_training and cfg.policy.use_peft:
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model) unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
else: else:
unwrapped_model.push_model_to_hub(cfg) unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
preprocessor.push_to_hub(active_cfg.repo_id) preprocessor.push_to_hub(active_cfg.repo_id)
postprocessor.push_to_hub(active_cfg.repo_id) postprocessor.push_to_hub(active_cfg.repo_id)
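Review note: the ordering constraint in the checkpoint hunk (gather on every rank first, write on rank 0 only) can be illustrated without any distributed backend. This is a control-flow sketch only. `checkpoint_step`, `gather`, and `write` are hypothetical stand-ins for the training loop, `gather_fsdp_state_dicts`, and `save_checkpoint`; no real torch or accelerate call is made, and the ranks are simulated one after another rather than running concurrently.

```python
def checkpoint_step(rank, is_fsdp, gather, write, log):
    """One checkpoint step as seen by a single rank (control flow only)."""
    if is_fsdp:
        # Collective: must run on every rank, before the main-process guard.
        model_sd, optim_sd = gather(rank)
        log.append(("gather", rank))
    else:
        # DDP / single device: the writer builds the state dicts itself.
        model_sd, optim_sd = None, None

    if rank == 0:
        # Only the main process touches the filesystem.
        write(model_sd, optim_sd)
        log.append(("write", rank))


log = []
written = []
for rank in range(4):  # simulated ranks, executed sequentially
    checkpoint_step(
        rank,
        is_fsdp=True,
        gather=lambda r: ({"w": 1.0}, {"step": 10}),
        write=lambda m, o: written.append((m, o)),
        log=log,
    )

print(sum(1 for event, _ in log if event == "gather"))  # 4
print(len(written))  # 1
```

If the gather were moved inside the `rank == 0` branch, the other ranks would never enter the collective, and a real FSDP run would be expected to hang at that point.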
@@ -18,7 +18,8 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
@@ -27,7 +28,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiOpenArmLeader(Teleoperator): class BiOpenArmLeader(BimanualMixin, Teleoperator):
""" """
Bimanual OpenArm Leader Arms Bimanual OpenArm Leader Arms
""" """
@@ -86,27 +87,6 @@ class BiOpenArmLeader(Teleoperator):
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
return {} return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
raise NotImplementedError( raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors." "Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -129,8 +109,3 @@ class BiOpenArmLeader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback # TODO: Implement force feedback
raise NotImplementedError raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_leader") @TeleoperatorConfig.register_subclass("bi_openarm_leader")
@dataclass @dataclass
class BiOpenArmLeaderConfig(TeleoperatorConfig): class BiOpenArmLeaderConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Follower robots.""" """Configuration class for Bi OpenArm Leader teleoperators."""
left_arm_config: OpenArmLeaderConfigBase left_arm_config: OpenArmLeaderConfigBase
right_arm_config: OpenArmLeaderConfigBase right_arm_config: OpenArmLeaderConfigBase
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_openarm_mini import BiOpenArmMini
from .config_bi_openarm_mini import BiOpenArmMiniConfig
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
@@ -0,0 +1,101 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
from ..teleoperator import Teleoperator
from .config_bi_openarm_mini import BiOpenArmMiniConfig
logger = logging.getLogger(__name__)
class BiOpenArmMini(BimanualMixin, Teleoperator):
"""Bimanual OpenArm Mini teleoperator.
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
bimanual leader can teleoperate a bimanual OpenArm follower.
"""
config_class = BiOpenArmMiniConfig
name = "bi_openarm_mini"
def __init__(self, config: BiOpenArmMiniConfig):
super().__init__(config)
self.config = config
# `side` is forced to match left/right regardless of what the user passed
# on the per-arm base config — the bimanual wrapper owns the side semantics.
left_arm_config = OpenArmMiniConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
side="left",
use_degrees=config.left_arm_config.use_degrees,
)
right_arm_config = OpenArmMiniConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
side="right",
use_degrees=config.right_arm_config.use_degrees,
)
self.left_arm = OpenArmMini(left_arm_config)
self.right_arm = OpenArmMini(right_arm_config)
@cached_property
def action_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
}
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
action: RobotAction = {}
for k, v in self.left_arm.get_action().items():
action[f"left_{k}"] = v
for k, v in self.right_arm.get_action().items():
action[f"right_{k}"] = v
return action
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
if left_fb:
self.left_arm.send_feedback(left_fb)
if right_fb:
self.right_arm.send_feedback(right_fb)
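Review note: the key routing done by `get_action` and `send_feedback` above can be checked in isolation. The following is a minimal sketch rather than the PR's code. `merge_arm_actions` and `split_feedback` are hypothetical helper names introduced for illustration, and plain dicts stand in for `RobotAction`.

```python
def merge_arm_actions(left: dict[str, float], right: dict[str, float]) -> dict[str, float]:
    """Namespace each arm's keys with a left_/right_ prefix and merge them."""
    merged = {f"left_{k}": v for k, v in left.items()}
    merged.update({f"right_{k}": v for k, v in right.items()})
    return merged


def split_feedback(feedback: dict[str, float]) -> tuple[dict[str, float], dict[str, float]]:
    """Route prefixed feedback keys back to per-arm dicts; unprefixed keys are ignored."""
    left = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
    right = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
    return left, right
```

Because both arms expose the same unprefixed key names, the prefix is the only thing that keeps the two sides from colliding in the merged dict. `str.removeprefix` requires Python 3.9 or later.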
@@ -0,0 +1,29 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
from ..openarm_mini import OpenArmMiniConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
@dataclass
class BiOpenArmMiniConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Mini teleoperators."""
left_arm_config: OpenArmMiniConfigBase
right_arm_config: OpenArmMiniConfigBase
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .bi_rebot_102_leader import BiRebotArm102Leader from .bi_rebot_102_leader import BiRebot102Leader
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"] __all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
@@ -18,16 +18,17 @@ import logging
from functools import cached_property from functools import cached_property
from lerobot.types import RobotAction from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiRebotArm102Leader(Teleoperator): class BiRebot102Leader(BimanualMixin, Teleoperator):
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader. """Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
@@ -35,10 +36,10 @@ class BiRebotArm102Leader(Teleoperator):
leader can teleoperate a bimanual reBot B601 follower. leader can teleoperate a bimanual reBot B601 follower.
""" """
config_class = BiRebotArm102LeaderConfig config_class = BiRebot102LeaderConfig
name = "bi_rebot_102_leader" name = "bi_rebot_102_leader"
def __init__(self, config: BiRebotArm102LeaderConfig): def __init__(self, config: BiRebot102LeaderConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
@@ -76,27 +77,6 @@ class BiRebotArm102Leader(Teleoperator):
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
return {} return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected @check_if_not_connected
def get_action(self) -> RobotAction: def get_action(self) -> RobotAction:
action_dict = {} action_dict = {}
@@ -106,8 +86,3 @@ class BiRebotArm102Leader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.") raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader") @TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
@dataclass @dataclass
class BiRebotArm102LeaderConfig(TeleoperatorConfig): class BiRebot102LeaderConfig(TeleoperatorConfig):
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator.""" """Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
left_arm_config: RebotArm102LeaderConfig left_arm_config: RebotArm102LeaderConfig
@@ -17,7 +17,9 @@
import logging import logging
from functools import cached_property from functools import cached_property
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..so_leader import SOLeader, SOLeaderTeleopConfig from ..so_leader import SOLeader, SOLeaderTeleopConfig
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
@@ -26,7 +28,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BiSOLeader(Teleoperator): class BiSOLeader(BimanualMixin, Teleoperator):
""" """
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio [Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
""" """
@@ -67,33 +69,12 @@ class BiSOLeader(Teleoperator):
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
return {} return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None: def setup_motors(self) -> None:
self.left_arm.setup_motors() self.left_arm.setup_motors()
self.right_arm.setup_motors() self.right_arm.setup_motors()
@check_if_not_connected @check_if_not_connected
def get_action(self) -> dict[str, float]: def get_action(self) -> RobotAction:
action_dict = {} action_dict = {}
# Add "left_" prefix # Add "left_" prefix
@@ -109,8 +90,3 @@ class BiSOLeader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback # TODO: Implement force feedback
raise NotImplementedError raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .config_openarm_mini import OpenArmMiniConfig from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
from .openarm_mini import OpenArmMini from .openarm_mini import OpenArmMini
__all__ = ["OpenArmMini", "OpenArmMiniConfig"] __all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
@@ -19,12 +19,21 @@ from dataclasses import dataclass
from ..config import TeleoperatorConfig from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass @dataclass
class OpenArmMiniConfig(TeleoperatorConfig): class OpenArmMiniConfigBase:
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms).""" """Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
port_right: str = "/dev/ttyUSB0" # Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
port_left: str = "/dev/ttyUSB1" port: str
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
# during readout. If `None`, no flipping is applied.
side: str | None = None
use_degrees: bool = True use_degrees: bool = True
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
pass
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Motors whose direction is inverted during readout # Per-side motor direction flips applied during readout.
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"] SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"] "left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
}
# Leader joint 6 maps to follower joint 7 and vice versa # Leader joint 6 <-> follower joint 7 (symmetric — its own inverse).
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"} JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
GRIPPER_TELEOP_TO_DEGREES = -0.65 GRIPPER_TELEOP_TO_DEGREES = -0.65
class OpenArmMini(Teleoperator): class OpenArmMini(Teleoperator):
""" """OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos. For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
""" """
config_class = OpenArmMiniConfig config_class = OpenArmMiniConfig
@@ -56,9 +56,12 @@ class OpenArmMini(Teleoperator):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
norm_mode_body = MotorNormMode.DEGREES norm_mode_body = MotorNormMode.DEGREES
motors = {
motors_right = {
"joint_1": Motor(1, "sts3215", norm_mode_body), "joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body), "joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body), "joint_3": Motor(3, "sts3215", norm_mode_body),
@@ -69,46 +72,15 @@ class OpenArmMini(Teleoperator):
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100), "gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
} }
motors_left = { self.bus = FeetechMotorsBus(
"joint_1": Motor(1, "sts3215", norm_mode_body), port=self.config.port,
"joint_2": Motor(2, "sts3215", norm_mode_body), motors=motors,
"joint_3": Motor(3, "sts3215", norm_mode_body), calibration=self.calibration,
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
) )
@property @property
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
# Right first, then left — matches the robot (BiOpenArmFollower) ordering return {f"{motor}.pos": float for motor in self.bus.motors}
# and the dataset feature names recorded during data collection.
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property @property
def feedback_features(self) -> dict[str, type]: def feedback_features(self) -> dict[str, type]:
@@ -116,14 +88,12 @@ class OpenArmMini(Teleoperator):
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self.bus_right.is_connected and self.bus_left.is_connected return self.bus.is_connected
@check_if_already_connected @check_if_already_connected
def connect(self, calibrate: bool = True) -> None: def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting right arm on {self.config.port_right}...") logger.info(f"Connecting arm on {self.config.port}...")
self.bus_right.connect() self.bus.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
if calibrate: if calibrate:
self.calibrate() self.calibrate()
@@ -133,14 +103,14 @@ class OpenArmMini(Teleoperator):
@property @property
def is_calibrated(self) -> bool: def is_calibrated(self) -> bool:
return self.bus_right.is_calibrated and self.bus_left.is_calibrated return self.bus.is_calibrated
def calibrate(self) -> None: def calibrate(self) -> None:
""" """
Run calibration procedure for OpenArm Mini. Run calibration procedure for a single OpenArm Mini arm.
1. Disable torque 1. Disable torque
2. Ask user to position arms in hanging position with grippers closed 2. Ask user to position arm in hanging position with gripper closed
3. Set this as zero position via half-turn homing 3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions) 4. Interactive gripper calibration (open/close positions)
5. Save calibration 5. Save calibration
@@ -152,70 +122,51 @@ class OpenArmMini(Teleoperator):
) )
if user_input.strip().lower() != "c": if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}") logger.info(f"Using existing calibration for {self.id}")
cal_right = { self.bus.write_calibration(self.calibration)
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return return
logger.info(f"\nRunning calibration for {self}") logger.info(f"\nRunning calibration for {self}")
self._calibrate_arm("right", self.bus_right) self.bus.disable_torque()
self._calibrate_arm("left", self.bus_left)
self._save_calibration() logger.info("Setting Phase to 12 for all motors...")
print(f"\nCalibration complete and saved to {self.calibration_fpath}") for motor in self.bus.motors:
self.bus.write("Phase", motor, 12)
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None: for motor in self.bus.motors:
"""Calibrate a single arm with Feetech motors.""" self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input( input(
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n" "\nCalibration: Zero Position\n"
"Position the arm in the following configuration:\n" "Position the arm in the following configuration:\n"
" - Arm hanging straight down\n" " - Arm hanging straight down\n"
" - Gripper closed\n" " - Gripper closed\n"
"Press ENTER when ready..." "Press ENTER when ready..."
) )
homing_offsets = bus.set_half_turn_homings() homing_offsets = self.bus.set_half_turn_homings()
logger.info(f"{arm_name.capitalize()} arm zero position set.") logger.info("Arm zero position set.")
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n") print("\nSetting motor ranges\n")
if self.calibration is None: if self.calibration is None:
self.calibration = {} self.calibration = {}
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model] motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
max_res = motor_resolution - 1 max_res = motor_resolution - 1
for motor_name, motor in bus.motors.items(): for motor_name, motor in self.bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
if motor_name == "gripper": if motor_name == "gripper":
input( input(
f"\nGripper Calibration ({arm_name.upper()} arm)\n" "\nGripper Calibration\n"
f"Step 1: CLOSE the gripper fully\n" "Step 1: CLOSE the gripper fully\n"
f"Press ENTER when gripper is closed..." "Press ENTER when gripper is closed..."
) )
closed_pos = bus.read("Present_Position", motor_name, normalize=False) closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}") logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...") input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = bus.read("Present_Position", motor_name, normalize=False) open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}") logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos: if closed_pos < open_pos:
@@ -228,16 +179,16 @@ class OpenArmMini(Teleoperator):
drive_mode = 1 drive_mode = 1
logger.info( logger.info(
f" {prefixed_name}: range set to [{range_min}, {range_max}] " f" {motor_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})" f"(0=closed, 100=open, drive_mode={drive_mode})"
) )
else: else:
range_min = 0 range_min = 0
range_max = max_res range_max = max_res
drive_mode = 0 drive_mode = 0
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)") logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[prefixed_name] = MotorCalibration( self.calibration[motor_name] = MotorCalibration(
id=motor.id, id=motor.id,
drive_mode=drive_mode, drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name], homing_offset=homing_offsets[motor_name],
@@ -245,108 +196,68 @@ class OpenArmMini(Teleoperator):
range_max=range_max, range_max=range_max,
) )
cal_for_bus = { self.bus.write_calibration(self.calibration)
k.replace(f"{arm_name}_", ""): v self._save_calibration()
for k, v in self.calibration.items() print(f"\nCalibration complete and saved to {self.calibration_fpath}")
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
def configure(self) -> None: def configure(self) -> None:
self.bus_right.disable_torque() self.bus.disable_torque()
self.bus_right.configure_motors() self.bus.configure_motors()
for motor in self.bus_right.motors: for motor in self.bus.motors:
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value) self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None: def setup_motors(self) -> None:
print("\nSetting up RIGHT arm motors...") for motor in reversed(self.bus.motors):
for motor in reversed(self.bus_right.motors): input(f"Connect the controller board to the '{motor}' motor only and press enter.")
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.") self.bus.setup_motor(motor)
self.bus_right.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
@check_if_not_connected @check_if_not_connected
def get_action(self) -> RobotAction: def get_action(self) -> RobotAction:
"""Get current action from both arms (read positions from all motors).""" """Get current action (read positions from all motors)."""
start = time.perf_counter() start = time.perf_counter()
right_positions = self.bus_right.sync_read("Present_Position") positions = self.bus.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa. # Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
# Per-side direction flip is applied based on the configured `side`.
action: dict[str, Any] = {} action: dict[str, Any] = {}
for motor, val in right_positions.items(): for motor, val in positions.items():
target = JOINT_REMAP.get(motor, motor) target = JOINT_REMAP.get(motor, motor)
if motor == "gripper": if motor == "gripper":
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65° # Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else: else:
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
for motor, val in left_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3 dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms") logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action return action
def enable_torque(self) -> None: def enable_torque(self) -> None:
"""Enable torque on both arms for position control.""" self.bus.enable_torque()
self.bus_right.enable_torque()
self.bus_left.enable_torque()
def disable_torque(self) -> None: def disable_torque(self) -> None:
"""Disable torque on both arms for free movement.""" self.bus.disable_torque()
self.bus_right.disable_torque()
self.bus_left.disable_torque()
def write_goal_positions(self, positions: dict[str, float]) -> None: def write_goal_positions(self, positions: dict[str, float]) -> None:
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic).""" """Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
right_goals: dict[str, float] = {} goals: dict[str, float] = {}
left_goals: dict[str, float] = {}
for key, val in positions.items(): for key, val in positions.items():
if not key.endswith(".pos"): if not key.endswith(".pos"):
continue continue
motor_name = key.removesuffix(".pos") base = key.removesuffix(".pos")
if motor_name.startswith("right_"): # JOINT_REMAP is symmetric (its own inverse).
base = motor_name.removeprefix("right_") target = JOINT_REMAP.get(base, base)
# Reverse remap: follower joint_7 → leader joint_6 and vice versa if base == "gripper":
target = JOINT_REMAP_REVERSE.get(base, base) # Convert robot degrees to teleop 0-100: 0°→0, -65°→100
if base == "gripper": goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100 else:
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES # Un-flip using the ORIGINAL motor name (target = leader motor)
else: goals[target] = -val if target in self._motors_to_flip else val
# Un-flip using the ORIGINAL motor name (target = leader motor)
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
elif motor_name.startswith("left_"):
base = motor_name.removeprefix("left_")
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
if right_goals: if goals:
self.bus_right.sync_write("Goal_Position", right_goals) self.bus.sync_write("Goal_Position", goals)
if left_goals:
self.bus_left.sync_write("Goal_Position", left_goals)
@check_if_not_connected @check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None: def send_feedback(self, feedback: dict[str, float]) -> None:
@@ -354,6 +265,5 @@ class OpenArmMini(Teleoperator):
@check_if_not_connected @check_if_not_connected
def disconnect(self) -> None: def disconnect(self) -> None:
self.bus_right.disconnect() self.bus.disconnect()
self.bus_left.disconnect()
logger.info(f"{self} disconnected.") logger.info(f"{self} disconnected.")
+6 -2
@@ -99,14 +99,18 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .openarm_mini import OpenArmMini from .openarm_mini import OpenArmMini
return OpenArmMini(config) return OpenArmMini(config)
elif config.type == "bi_openarm_mini":
from .bi_openarm_mini import BiOpenArmMini
return BiOpenArmMini(config)
elif config.type == "rebot_102_leader": elif config.type == "rebot_102_leader":
from .rebot_102_leader import RebotArm102Leader from .rebot_102_leader import RebotArm102Leader
return RebotArm102Leader(config) return RebotArm102Leader(config)
elif config.type == "bi_rebot_102_leader": elif config.type == "bi_rebot_102_leader":
from .bi_rebot_102_leader import BiRebotArm102Leader from .bi_rebot_102_leader import BiRebot102Leader
return BiRebotArm102Leader(config) return BiRebot102Leader(config)
else: else:
try: try:
return cast("Teleoperator", make_device_from_device_class(config)) return cast("Teleoperator", make_device_from_device_class(config))
+63
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
class BimanualMixin:
"""Lifecycle delegation for bimanual robots and teleoperators.
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
their own ``__init__``. They retain ownership of feature dicts and the
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
``send_feedback``), which vary per-embodiment.
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
take precedence in the MRO::
class BiFooFollower(BimanualMixin, Robot): ...
"""
left_arm: Any
right_arm: Any
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
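Review note: the docstring's advice to list the mixin before the base class matters because of how abstract methods are resolved. The sketch below illustrates this with stand-ins only: `Device`, `LifecycleMixin`, `FakeArm`, `BiFake`, and `BiFakeWrongOrder` are hypothetical names, `Device` is assumed to be an abstract base like `Robot` / `Teleoperator`, and the connection-guard decorators are omitted.

```python
import abc


class Device(abc.ABC):
    """Hypothetical stand-in for an abstract Robot / Teleoperator base."""

    @property
    @abc.abstractmethod
    def is_connected(self) -> bool: ...

    @abc.abstractmethod
    def connect(self, calibrate: bool = True) -> None: ...


class LifecycleMixin:
    """Simplified stand-in for BimanualMixin: delegates to both arms."""

    @property
    def is_connected(self) -> bool:
        return self.left_arm.is_connected and self.right_arm.is_connected

    def connect(self, calibrate: bool = True) -> None:
        self.left_arm.connect(calibrate)
        self.right_arm.connect(calibrate)


class FakeArm:
    """Minimal single-arm device used only for this illustration."""

    def __init__(self) -> None:
        self.is_connected = False

    def connect(self, calibrate: bool = True) -> None:
        self.is_connected = True


class BiFake(LifecycleMixin, Device):
    # Mixin first: its concrete methods shadow the abstract ones in the MRO.
    def __init__(self) -> None:
        self.left_arm = FakeArm()
        self.right_arm = FakeArm()


class BiFakeWrongOrder(Device, LifecycleMixin):
    # Base first: the abstract methods win, so the class stays abstract.
    def __init__(self) -> None:
        self.left_arm = FakeArm()
        self.right_arm = FakeArm()
```

With the mixin first, `BiFake()` can be instantiated and `connect()` reaches both arms. With the order reversed, instantiating `BiFakeWrongOrder()` raises `TypeError` because `connect` and `is_connected` still resolve to the abstract definitions.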
+8 -2
@@ -216,9 +216,15 @@ def register_third_party_plugins() -> None:
     This function uses `importlib.metadata` to find packages installed in the environment
     (including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
-    'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
+    'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
     """
-    prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
+    prefixes = (
+        "lerobot_robot_",
+        "lerobot_camera_",
+        "lerobot_teleoperator_",
+        "lerobot_policy_",
+        "lerobot_env_",
+    )
     imported: list[str] = []
     failed: list[str] = []
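A rough sketch of prefix-based discovery with `importlib.metadata`, under stated assumptions: `select_plugin_names` and `installed_distribution_names` are hypothetical helpers, the dash-to-underscore normalization is a guess at how distribution names map to the prefixes, and the import step is omitted. It is not the body of `register_third_party_plugins`.

```python
from importlib import metadata

# Prefix tuple as it appears on the new side of the hunk above.
PLUGIN_PREFIXES = (
    "lerobot_robot_",
    "lerobot_camera_",
    "lerobot_teleoperator_",
    "lerobot_policy_",
    "lerobot_env_",
)


def select_plugin_names(names, prefixes=PLUGIN_PREFIXES):
    """Return the sorted, de-duplicated names that start with a plugin prefix."""
    # str.startswith accepts a tuple, so one call checks every prefix.
    return sorted({name for name in names if name.startswith(prefixes)})


def installed_distribution_names():
    """Names of installed distributions, normalized to underscore form (assumption)."""
    names = []
    for dist in metadata.distributions():
        meta = dist.metadata
        name = meta.get("Name") if meta is not None else None
        if name:  # skip distributions with missing or broken metadata
            names.append(name.replace("-", "_").lower())
    return names


candidates = select_plugin_names(installed_distribution_names())
```

With the tuple extended by `"lerobot_env_"`, a package named `lerobot_env_foo` would be selected, while it would have been skipped by the old four-prefix tuple.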
+73
@@ -28,6 +28,7 @@ import pytest
 pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
 pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
+import pandas as pd # noqa: E402
 import pyarrow.parquet as pq # noqa: E402
 from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
@@ -344,6 +345,78 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
     assert len(dataset) == 24
+
+
+def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
+    """Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
+    so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
+    from lerobot.datasets.io_utils import write_tasks
+    from lerobot.utils.io_utils import write_json
+
+    data_dir = root / "data" / "chunk-000"
+    data_dir.mkdir(parents=True, exist_ok=True)
+    episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
+    for ep, length in enumerate(episode_lengths):
+        episode_index += [ep] * length
+        frame_index += list(range(length))
+        timestamp += [round(i / fps, 6) for i in range(length)]
+        task_index += [0] * length
+        subtask_index += [0] * length # legacy column the writer must drop
+    pd.DataFrame(
+        {
+            "episode_index": episode_index,
+            "frame_index": frame_index,
+            "timestamp": timestamp,
+            "task_index": task_index,
+            "subtask_index": subtask_index,
+        }
+    ).to_parquet(data_dir / "file-000.parquet", index=False)
+    tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
+    write_tasks(tasks_df, root)
+    write_json(
+        {"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
+        root / "meta" / "info.json",
+    )
+    return root
+
+
+def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
+    """Rewriting a packed shard must keep one row group per episode, not collapse
+    every episode into a single giant row group."""
+    episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
+    root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
+    shard = root / "data" / "chunk-000" / "file-000.parquet"
+    assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
+    staging_dir = tmp_path / "stage"
+    for ep in range(len(episode_lengths)):
+        _stage_episode(
+            staging_dir,
+            ep,
+            plan=[
+                {
+                    "role": "assistant",
+                    "content": f"subtask for ep {ep}",
+                    "style": "subtask",
+                    "timestamp": 0.0,
+                    "tool_calls": None,
+                }
+            ],
+        )
+    records = list(iter_episodes(root))
+    LanguageColumnsWriter().write_all(records, staging_dir, root)
+    # One row group per episode, with row counts matching the episode lengths.
+    md = pq.ParquetFile(shard).metadata
+    assert md.num_row_groups == len(episode_lengths)
+    assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
+    # Language columns are still present after the per-episode rewrite.
+    table = pq.read_table(shard)
+    assert "language_persistent" in table.column_names
+    assert "language_events" in table.column_names
+
+
 def test_speech_atom_shape_matches_plan_spec() -> None:
     atom = speech_atom(2.5, "I'm cleaning up!")
     assert atom["role"] == "assistant"
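The invariant this test asserts (row-group sizes equal to episode lengths) can be illustrated without Parquet at all. The sketch below is pure Python and only computes where the row-group boundaries would fall from an `episode_index` column; `row_group_sizes` is a hypothetical helper, not part of the writer.

```python
from itertools import groupby


def row_group_sizes(episode_index):
    """Length of each contiguous run of equal episode indices."""
    # groupby splits the column wherever the episode index changes,
    # which is exactly where a per-episode row group should start.
    return [len(list(run)) for _, run in groupby(episode_index)]


# Same fixture shape as the test: three episodes of unequal length in one shard.
episode_lengths = [4, 6, 5]
episode_index = [ep for ep, n in enumerate(episode_lengths) for _ in range(n)]
sizes = row_group_sizes(episode_index)
```

A writer that emits one row group per run would produce row groups whose `num_rows` match `sizes`, which is what the test checks against `episode_lengths`.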
+55
@@ -32,6 +32,26 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from tests.fixtures.constants import DUMMY_REPO_ID
+
+
+def assert_data_shards_one_row_group_per_episode(root):
+    """Every aggregated DATA shard must have exactly one parquet row group per episode."""
+    import pyarrow.parquet as pq
+
+    shards = sorted((root / "data").rglob("*.parquet"))
+    assert shards, f"no data shards found under {root}/data"
+    n_episodes = 0
+    for shard in shards:
+        pf = pq.ParquetFile(shard)
+        episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
+        assert pf.metadata.num_row_groups == len(set(episodes)), shard
+        for i in range(pf.metadata.num_row_groups):
+            rg_episodes = set(
+                pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
+            )
+            assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
+        n_episodes += len(set(episodes))
+    return n_episodes
+
+
 def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
     """Test that total number of episodes and frames are correctly aggregated."""
     assert aggr_ds.num_episodes == expected_episodes, (
@@ -566,6 +586,41 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
     )
+
+
+@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
+def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
+    """Aggregated DATA shards keep one row group per episode (not one collapsed group).
+
+    Covers both the non-image (``df.to_parquet``) and image
+    (``to_parquet_with_hf_images``) write branches, including the merge-into-
+    existing-file branch via a low file-size threshold that forces packing.
+    """
+    ds_0 = lerobot_dataset_factory(
+        root=tmp_path / "rg_0",
+        repo_id=f"{DUMMY_REPO_ID}_rg_0",
+        total_episodes=3,
+        total_frames=60,
+        use_videos=use_videos,
+    )
+    ds_1 = lerobot_dataset_factory(
+        root=tmp_path / "rg_1",
+        repo_id=f"{DUMMY_REPO_ID}_rg_1",
+        total_episodes=4,
+        total_frames=80,
+        use_videos=use_videos,
+    )
+    aggr_root = tmp_path / "rg_aggr"
+    aggregate_datasets(
+        repo_ids=[ds_0.repo_id, ds_1.repo_id],
+        roots=[ds_0.root, ds_1.root],
+        aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
+        aggr_root=aggr_root,
+    )
+    n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
+    assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
+
+
 def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
     """Test aggregation of image-based datasets preserves HuggingFace Image schema.
+16 -1
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
 from lerobot.transforms import ImageTransforms, ImageTransformsConfig
 from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
 from lerobot.utils.feature_utils import hw_to_dataset_features
-from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
+from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID
 from tests.mocks.mock_robot import MockRobotConfig
 from tests.utils import require_x86_64_kernel
@@ -133,6 +133,21 @@ def test_dataset_feature_with_forward_slash_raises_error():
     )
+
+
+def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
+    # ``create`` must deep-copy features so a dataset built from another's features stays independent.
+    dataset = empty_lerobot_dataset_factory(
+        root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
+    )
+    dataset_copy = empty_lerobot_dataset_factory(
+        root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False
+    )
+    original_shape = dataset.meta.info.features["state"]["shape"]
+    dataset_copy.meta.info.features["state"]["shape"] = (999,)
+    assert dataset.meta.info.features["state"]["shape"] == original_shape
+
+
 def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
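The aliasing bug that the deep-copy test guards against can be reproduced with plain dictionaries. This is an illustrative sketch only; the `features` dict mimics the shape used in the tests and does not touch any LeRobot API.

```python
import copy

features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}

shallow = dict(features)        # new outer dict, but the inner dicts are shared
deep = copy.deepcopy(features)  # fully independent copy

# Mutating the deep copy leaves the original untouched.
deep["state"]["shape"] = (999,)
unchanged_after_deep = features["state"]["shape"]

# Mutating the shallow copy leaks into the original through the shared inner dict.
shallow["state"]["shape"] = (2,)
changed_after_shallow = features["state"]["shape"]
```

A second dataset built from `dataset.meta.features` without a deep copy would behave like `shallow` here, so editing its `shape` would silently edit the first dataset's metadata too.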
+39
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import (
     MultiAdamConfig,
     SGDConfig,
     load_optimizer_state,
+    load_optimizer_state_dict,
     save_optimizer_state,
 )
 from lerobot.utils.constants import (
@@ -65,6 +66,44 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
     torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
+
+
+def test_save_and_load_fsdp_optimizer_state_dict_roundtrip(tmp_path):
+    """The FSDP full optimizer state dict is keyed by parameter FQNs (dotted strings), not the
+    integer indices of the single-GPU path. Verify it survives the safetensors save -> read
+    round-trip used by the FSDP save/resume path (save_optimizer_state(optim_state_dict=...) then
+    load_optimizer_state_dict), which the flatten/unflatten "/" separator must not corrupt."""
+    full_osd = {
+        "state": {
+            "model.layers.0.weight": {
+                "step": torch.tensor(3.0),
+                "exp_avg": torch.randn(4, 4),
+                "exp_avg_sq": torch.randn(4, 4),
+            },
+            "model.layers.0.bias": {
+                "step": torch.tensor(3.0),
+                "exp_avg": torch.randn(4),
+                "exp_avg_sq": torch.randn(4),
+            },
+        },
+        "param_groups": [
+            {"lr": 1e-4, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.0, "params": [0, 1]}
+        ],
+    }
+    save_optimizer_state(
+        torch.optim.Adam([torch.nn.Parameter(torch.randn(1))]), tmp_path, optim_state_dict=full_osd
+    )
+    assert (tmp_path / OPTIMIZER_STATE).is_file()
+    assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file()
+    loaded = load_optimizer_state_dict(tmp_path)
+    # FQN keys must be preserved verbatim (not int-cast, not split on their dots).
+    assert set(loaded["state"].keys()) == set(full_osd["state"].keys())
+    for fqn, sub in full_osd["state"].items():
+        for k, v in sub.items():
+            torch.testing.assert_close(loaded["state"][fqn][k], v)
+    assert loaded["param_groups"] == full_osd["param_groups"]
+
+
 @pytest.fixture
 def base_params_dict():
     return {
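Why dotted FQN keys can survive a `"/"`-separated flatten/unflatten round trip is easiest to see on a toy example. The sketch below is an assumption-laden illustration: `flatten` and `unflatten` are hypothetical helpers written for this note, plain floats and lists stand in for tensors, and LeRobot's own flatten/unflatten utilities may differ in detail.

```python
def flatten(nested, sep="/", prefix=""):
    """Collapse a nested dict into one level, joining keys with `sep`."""
    flat = {}
    for key, value in nested.items():
        path = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, sep, path))
        else:
            flat[path] = value
    return flat


def unflatten(flat, sep="/"):
    """Rebuild the nested dict by splitting each key on `sep` only."""
    nested = {}
    for path, value in flat.items():
        *parents, leaf = path.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# FQN-keyed optimizer state, shaped like the test fixture above.
state = {
    "model.layers.0.weight": {"step": 3.0, "exp_avg": [0.1, 0.2]},
    "model.layers.0.bias": {"step": 3.0, "exp_avg": [0.3]},
}
flat = flatten(state)
restored = unflatten(flat)
```

Because the separator is `"/"` and the FQNs only contain dots, splitting never cuts through a parameter name. The round trip would break if a key itself contained `"/"`, or if the loader cast keys to `int` as the index-keyed single-GPU path can.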
+24
@@ -23,6 +23,7 @@ import torch
 pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from packaging import version
 from safetensors.torch import load_file
@@ -300,6 +301,29 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
     torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
+
+
+def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
+    """Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
+    policy_cls = get_policy_class("act")
+    policy_cfg = make_policy_config("act")
+    features = dataset_to_policy_features(dummy_dataset_metadata.features)
+    policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+    policy_cfg.input_features = {
+        key: ft for key, ft in features.items() if key not in policy_cfg.output_features
+    }
+    policy = policy_cls(policy_cfg)
+    policy.to(policy_cfg.device)
+    save_dir = tmp_path / "fsdp_state_dict"
+    policy.save_pretrained(save_dir, state_dict=policy.state_dict())
+    # A single, unsharded safetensors file (no sharded set + index).
+    assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
+    assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
+    loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
+    torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
+
+
 @pytest.mark.parametrize("multikey", [True, False])
 def test_multikey_construction(multikey: bool):
     """
+21 -3
@@ -2370,14 +2370,32 @@ def test_aggregate_images_when_use_videos_false():
     out = aggregate_pipeline_dataset_features(
         pipeline=rp,
         initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
-        use_videos=False, # expect "image" dtype
+        use_videos=False, # images kept, stored as "image" dtype
         patterns=None,
     )
     key = f"{OBS_IMAGES}.back"
     key_front = f"{OBS_IMAGES}.front"
-    assert key not in out
-    assert key_front not in out
+    assert key in out
+    assert key_front in out
+    assert out[key]["dtype"] == "image"
+    assert out[key_front]["dtype"] == "image"
+    assert out[key]["shape"] == initial["back"]
+
+
+def test_aggregate_images_excluded():
+    rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
+    initial = {"back": (480, 640, 3)}
+    out = aggregate_pipeline_dataset_features(
+        pipeline=rp,
+        initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
+        exclude_images=True,
+        patterns=None,
+    )
+    assert f"{OBS_IMAGES}.back" not in out
+    assert f"{OBS_IMAGES}.front" not in out
+
+
 def test_aggregate_images_when_use_videos_true():
+3 -3
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
 import pytest
-from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
+from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
 from lerobot.teleoperators.rebot_102_leader import (
     RebotArm102Leader,
     RebotArm102LeaderConfig,
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
 def test_bimanual_prefixes_features():
     with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
-        cfg = BiRebotArm102LeaderConfig(
+        cfg = BiRebot102LeaderConfig(
             left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
             right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
         )
-        teleop = BiRebotArm102Leader(cfg)
+        teleop = BiRebot102Leader(cfg)
     assert any(k.startswith("left_") for k in teleop.action_features)
     assert any(k.startswith("right_") for k in teleop.action_features)
     assert "left_gripper.pos" in teleop.action_features
+108 -13
@@ -58,7 +58,46 @@ def download_dataset(repo_id, episodes):
     print(f"Dataset {repo_id} downloaded successfully")
-def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
+def _write_multi_gpu_config(f, num_processes):
+    f.write("compute_environment: LOCAL_MACHINE\n")
+    f.write("distributed_type: MULTI_GPU\n")
+    f.write("mixed_precision: 'no'\n")
+    f.write(f"num_processes: {num_processes}\n")
+    f.write("use_cpu: false\n")
+    f.write("gpu_ids: all\n")
+    f.write("downcast_bf16: 'no'\n")
+    f.write("machine_rank: 0\n")
+    f.write("main_training_function: main\n")
+    f.write("num_machines: 1\n")
+    f.write("rdzv_backend: static\n")
+    f.write("same_network: true\n")
+
+
+def _write_fsdp_config(f, num_processes):
+    # FSDP1 with FULL_SHARD (ZeRO-3-equivalent) and FULL_STATE_DICT, matching
+    # docs/source/multi_gpu_training.mdx. ACT's repeated transformer blocks are the wrap units;
+    # fsdp_use_orig_params is required because LeRobot builds the optimizer before prepare().
+    f.write("compute_environment: LOCAL_MACHINE\n")
+    f.write("distributed_type: FSDP\n")
+    f.write("mixed_precision: 'no'\n")
+    f.write(f"num_processes: {num_processes}\n")
+    f.write("use_cpu: false\n")
+    f.write("gpu_ids: all\n")
+    f.write("machine_rank: 0\n")
+    f.write("main_training_function: main\n")
+    f.write("num_machines: 1\n")
+    f.write("rdzv_backend: static\n")
+    f.write("same_network: true\n")
+    f.write("fsdp_config:\n")
+    f.write(" fsdp_version: 1\n")
+    f.write(" fsdp_sharding_strategy: FULL_SHARD\n")
+    f.write(" fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n")
+    f.write(" fsdp_transformer_layer_cls_to_wrap: ACTEncoderLayer,ACTDecoderLayer\n")
+    f.write(" fsdp_use_orig_params: true\n")
+    f.write(" fsdp_state_dict_type: FULL_STATE_DICT\n")
+
+
+def run_accelerate_training(config_args, num_processes=4, temp_dir=None, distributed_type="MULTI_GPU"):
     """
     Helper function to run training with accelerate launch.
@@ -66,6 +105,7 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
         config_args: List of config arguments to pass to lerobot_train.py
         num_processes: Number of processes (GPUs) to use
         temp_dir: Temporary directory for outputs
+        distributed_type: "MULTI_GPU" (DDP) or "FSDP" selects the generated accelerate config.
     Returns:
         subprocess.CompletedProcess result
@@ -75,18 +115,10 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
     # Write YAML config
     with open(config_path, "w") as f:
-        f.write("compute_environment: LOCAL_MACHINE\n")
-        f.write("distributed_type: MULTI_GPU\n")
-        f.write("mixed_precision: 'no'\n")
-        f.write(f"num_processes: {num_processes}\n")
-        f.write("use_cpu: false\n")
-        f.write("gpu_ids: all\n")
-        f.write("downcast_bf16: 'no'\n")
-        f.write("machine_rank: 0\n")
-        f.write("main_training_function: main\n")
-        f.write("num_machines: 1\n")
-        f.write("rdzv_backend: static\n")
-        f.write("same_network: true\n")
+        if distributed_type == "FSDP":
+            _write_fsdp_config(f, num_processes)
+        else:
+            _write_multi_gpu_config(f, num_processes)
     cmd = [
         "accelerate",
@@ -211,3 +243,66 @@ class TestMultiGPUTraining:
             # Verify optimizer state exists
             optimizer_state = training_state_dir / "optimizer_state.safetensors"
             assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
+
+    def test_fsdp_optimizer_save_and_resume(self):
+        """
+        Test that FSDP saves the (gathered) optimizer state and can resume from it.
+
+        Trains a few steps under FSDP, verifies the gathered optimizer state is written next to the
+        rest of the training state, then resumes from the checkpoint for more steps and checks it
+        completes without shape/key errors in the FSDP optimizer load path.
+        """
+        # Pre-download dataset to avoid race conditions
+        download_dataset("lerobot/pusht", episodes=[0])
+        with tempfile.TemporaryDirectory() as temp_dir:
+            output_dir = Path(temp_dir) / "outputs"
+            config_args = [
+                "--dataset.repo_id=lerobot/pusht",
+                "--dataset.episodes=[0]",
+                "--policy.type=act",
+                "--policy.device=cuda",
+                "--policy.push_to_hub=false",
+                f"--output_dir={output_dir}",
+                "--batch_size=4",
+                "--steps=10",
+                "--eval_freq=-1",
+                "--log_freq=5",
+                "--save_freq=10",
+                "--seed=42",
+                "--num_workers=0",
+            ]
+            result = run_accelerate_training(
+                config_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
+            )
+            assert result.returncode == 0, (
+                f"FSDP training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
+            )
+            # The gathered optimizer state must be written under FSDP (proves the save collective ran),
+            # in the same safetensors format as single-GPU training.
+            training_state_dir = output_dir / "checkpoints" / "last" / "training_state"
+            optimizer_state = training_state_dir / "optimizer_state.safetensors"
+            optimizer_param_groups = training_state_dir / "optimizer_param_groups.json"
+            assert optimizer_state.exists(), f"FSDP optimizer state not saved in {training_state_dir}"
+            assert optimizer_param_groups.exists(), (
+                f"FSDP optimizer param groups not saved in {training_state_dir}"
+            )
+            # Resume from the checkpoint for more steps. A successful run proves load_fsdp_optimizer
+            # accepts the saved state and reshards it without shape/key errors.
+            resume_config = output_dir / "checkpoints" / "last" / "pretrained_model" / "train_config.json"
+            resume_args = [
+                f"--config_path={resume_config}",
+                "--resume=true",
+                "--steps=20",
+            ]
+            resume_result = run_accelerate_training(
+                resume_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
+            )
+            assert resume_result.returncode == 0, (
+                f"FSDP resume failed:\nSTDOUT:\n{resume_result.stdout}\n\nSTDERR:\n{resume_result.stderr}"
+            )
+            assert "End of training" in resume_result.stdout or "End of training" in resume_result.stderr
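As an alternative to one `f.write` per line, the same FSDP settings could be kept in a dict and rendered in a single pass. This is only a sketch: `render_yaml` and `fsdp_accelerate_config` are hypothetical helpers, the renderer handles just the flat scalars and one nested mapping used here (it is not a general YAML emitter), and the keys and values are copied from `_write_fsdp_config` in the diff.

```python
def render_yaml(mapping, indent=0):
    """Render a nested dict of scalars as simple block-style YAML lines."""
    lines = []
    pad = " " * indent
    for key, value in mapping.items():
        if isinstance(value, dict):
            lines.append(f"{pad}{key}:")
            lines.extend(render_yaml(value, indent + 2))
        elif isinstance(value, bool):
            # YAML booleans are lowercase, unlike Python's True/False.
            lines.append(f"{pad}{key}: {'true' if value else 'false'}")
        else:
            lines.append(f"{pad}{key}: {value}")
    return lines


def fsdp_accelerate_config(num_processes):
    """Settings taken from the _write_fsdp_config helper shown in the diff."""
    return {
        "compute_environment": "LOCAL_MACHINE",
        "distributed_type": "FSDP",
        "mixed_precision": "'no'",  # quoted so YAML reads a string, not a boolean
        "num_processes": num_processes,
        "use_cpu": False,
        "gpu_ids": "all",
        "machine_rank": 0,
        "main_training_function": "main",
        "num_machines": 1,
        "rdzv_backend": "static",
        "same_network": True,
        "fsdp_config": {
            "fsdp_version": 1,
            "fsdp_sharding_strategy": "FULL_SHARD",
            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            "fsdp_transformer_layer_cls_to_wrap": "ACTEncoderLayer,ACTDecoderLayer",
            "fsdp_use_orig_params": True,
            "fsdp_state_dict_type": "FULL_STATE_DICT",
        },
    }


config_text = "\n".join(render_yaml(fsdp_accelerate_config(2))) + "\n"
```

The explicit `f.write` form in the test has the advantage of being trivially greppable; the dict form mainly helps when the DDP and FSDP configs share most keys.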
+15
@@ -136,3 +136,18 @@ def test_save_load_training_state(tmp_path, optimizer, scheduler):
     assert loaded_step == 10
     assert loaded_optimizer is optimizer
     assert loaded_scheduler is scheduler
+
+
+def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
+    # FSDP loads optimizer separately (after accelerator.prepare)
+    # load_training_state(load_optimizer=False) must restore step + scheduler but leave the
+    # optimizer untouched and never touch the on-disk optimizer state.
+    save_training_state(tmp_path, 10, optimizer, scheduler)
+    with patch("lerobot.common.train_utils.load_optimizer_state") as mock_load_optimizer_state:
+        loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(
+            tmp_path, optimizer, scheduler, load_optimizer=False
+        )
+    mock_load_optimizer_state.assert_not_called()
+    assert loaded_step == 10
+    assert loaded_optimizer is optimizer
+    assert loaded_scheduler is scheduler
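The "skip the optimizer" contract checked above can be sketched with a mock and a toy loader. `load_state` is a hypothetical stand-in written for this note, not `load_training_state`; it only shows the shape of the flag and how `assert_not_called` proves the optimizer path was never entered.

```python
from unittest.mock import MagicMock


def load_state(step_loader, optimizer_loader, optimizer, load_optimizer=True):
    """Hypothetical stand-in: restore the step, and the optimizer only on request."""
    step = step_loader()
    if load_optimizer:
        optimizer = optimizer_loader(optimizer)
    return step, optimizer


optimizer = object()
optimizer_loader = MagicMock()

# With load_optimizer=False the caller gets its own optimizer object back,
# and the loader that would read on-disk optimizer state is never invoked.
step, returned = load_state(lambda: 10, optimizer_loader, optimizer, load_optimizer=False)
```

Under FSDP the caller would then load the optimizer state separately after `accelerator.prepare`, as the test comment describes.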
Generated
+949 -900
File diff suppressed because it is too large