Compare commits

..

6 Commits

Author SHA1 Message Date
Maxime Ellerbach 24a43c8180 refactored initial implementation to use torch fsdp api and adding new tests 2026-06-18 12:15:03 +00:00
Maxime Ellerbach 2e0deff3ab fixing final upload to hub 2026-06-17 09:42:05 +00:00
Maxime Ellerbach b42d124007 cleanup 2026-06-15 14:50:23 +00:00
Maxime Ellerbach 3ce50c3468 adding a test for the fsdp checkpoint path 2026-06-15 14:36:22 +00:00
Maxime Ellerbach 44fd3c0a0e adding docs for FSDP 2026-06-15 14:15:09 +00:00
Maxime Ellerbach 0483afc743 feat(train): FSDP checkpoint saving 2026-06-15 14:03:17 +00:00
50 changed files with 1772 additions and 2035 deletions
+8 -8
View File
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
**Compatible teleoperators:**
- `bi_openarm_mini` - Bimanual OpenArm Mini
- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \
+1 -1
View File
@@ -117,7 +117,7 @@ lerobot-rollout \
--strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \
--teleop.type=bi_openarm_mini \
--teleop.type=openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt"
```
+55
View File
@@ -113,6 +113,61 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
--policy=act
```
## Training Large Models with FSDP
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
An example on how to launch LeRobot training with FSDP across 4 GPUs (1 machine):
```bash
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
--dataset.repo_id=${HF_USER}/my_dataset \
--policy.type=<your_policy> \
--output_dir=outputs/train/my_policy_fsdp
```
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_machines: 1
num_processes: 4
fsdp_config:
fsdp_version: 1
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
fsdp_state_dict_type: FULL_STATE_DICT
```
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
optimizer before `accelerator.prepare()`.
### FSDP checkpoints
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two things to look out for:
- **Checkpoints store fp32 weights.** Under mixed precision (`bf16`/`fp16`) FSDP keeps an fp32 master
copy, and the checkpoint saves it (~2× the bf16 size on disk) so training can resume consistently
with the fp32 optimizer state; `from_pretrained` casts back to the policy dtype on load. FSDP-specific
caveat: an fp32 checkpoint is materialized in full precision on the target device _before_ casting,
so loading it for inference on a tight GPU can OOM even when the bf16 model would fit — load on CPU
first, or cast `model.safetensors` to the deployment dtype offline.
- The sharded optimizer state is gathered into a full (world-size-independent) state dict and saved
alongside the model in the same `optimizer_state.safetensors` / `optimizer_param_groups.json`
format as single-GPU training, so **resume-from-checkpoint is supported** with `--resume=true`.
Resume reshards both the model and the optimizer state to the _current_ FSDP topology, so you can
resume an FSDP checkpoint on a different number of GPUs. Note that the data sampler is only
sample-exact when the world size and batch size match the original run (a warning is logged
otherwise); the optimizer/model state itself is unaffected.
## Notes
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
@@ -1,79 +0,0 @@
#!/usr/bin/env python
"""Convert a legacy LeRobot checkpoint to the current processor-pipeline format.
Older hub checkpoints (e.g. ``lerobot/act_aloha_sim_insertion_human``) bake
normalization stats into the model weights and do not ship
``policy_preprocessor.json`` / ``policy_postprocessor.json``. Current ``main``
loads those processor configs from the checkpoint, so eval/rollout fail with
``FileNotFoundError: Could not find 'policy_preprocessor.json'``.
This script rebuilds the processors from the training dataset's stats and saves
a pipeline-format checkpoint locally that ``lerobot-eval`` can consume directly.
Usage:
python examples/onnx/convert_legacy_checkpoint.py \
--policy-path=lerobot/act_aloha_sim_insertion_human \
--dataset-repo-id=lerobot/aloha_sim_insertion_human \
--output-dir=outputs/converted/act_aloha_sim_insertion_human
Then:
lerobot-eval \
--policy.path=outputs/converted/act_aloha_sim_insertion_human \
--env.type=aloha --env.task=AlohaInsertion-v0 \
--eval.batch_size=10 --eval.n_episodes=50 \
--eval.use_async_envs=false --policy.device=cuda
"""
import argparse
from pathlib import Path
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--policy-path", required=True, help="Legacy checkpoint repo id or local dir")
parser.add_argument(
"--dataset-repo-id",
required=True,
help="Training dataset repo id, used only for normalization stats",
)
parser.add_argument("--output-dir", required=True, help="Where to save the converted checkpoint")
parser.add_argument("--device", default="cpu", help="Device for building the policy (cpu is fine)")
args = parser.parse_args()
out = Path(args.output_dir)
out.mkdir(parents=True, exist_ok=True)
print(f"[1/4] Loading dataset stats from '{args.dataset_repo_id}' (metadata only)...")
ds_meta = LeRobotDatasetMetadata(args.dataset_repo_id)
print(f"[2/4] Loading policy weights from '{args.policy_path}'...")
cfg = PreTrainedConfig.from_pretrained(args.policy_path)
cfg.pretrained_path = args.policy_path
cfg.device = args.device
policy = make_policy(cfg, ds_meta=ds_meta)
print("[3/4] Building processors from dataset stats...")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy.config,
dataset_stats=ds_meta.stats,
)
print(f"[4/4] Saving pipeline-format checkpoint to '{out}'...")
policy.save_pretrained(out)
preprocessor.save_pretrained(out, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json")
postprocessor.save_pretrained(out, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json")
print(f"\nDone. Converted checkpoint at: {out}")
print("Eval it with --policy.path=" + str(out))
if __name__ == "__main__":
main()
-179
View File
@@ -1,179 +0,0 @@
#!/usr/bin/env python
"""Evaluate an ACT policy in sim with either the PyTorch or ONNX network.
The ONNX backend swaps only ``policy.model`` (ResNet + transformer + action head)
with an onnxruntime session. Everything else - the LeRobot processor pipeline
(normalization), the action queue, and the gym env - is identical, so any
difference in success rate is attributable to the network backend alone.
Run both backends with the same seed to compare:
python examples/onnx/eval_act_onnx.py \
--policy-path=lerobot/act_aloha_sim_transfer_cube_human \
--task=AlohaTransferCube-v0 \
--backend=torch --n-episodes=50 --batch-size=10 --device=cuda
python examples/onnx/eval_act_onnx.py \
--policy-path=lerobot/act_aloha_sim_transfer_cube_human \
--task=AlohaTransferCube-v0 \
--onnx=outputs/onnx/act_transfer_cube.onnx \
--backend=onnx --n-episodes=50 --batch-size=10 --device=cuda
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from torch import nn
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.scripts.lerobot_eval import eval_policy
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import set_seed
class ONNXACTModel(nn.Module):
"""Drop-in replacement for ``ACTPolicy.model`` backed by onnxruntime."""
def __init__(self, onnx_path: str, image_keys: list[str], has_state: bool, has_env_state: bool, device: str):
super().__init__()
import onnxruntime as ort
providers = (
["CUDAExecutionProvider", "CPUExecutionProvider"]
if str(device).startswith("cuda")
else ["CPUExecutionProvider"]
)
so = ort.SessionOptions()
so.log_severity_level = 3
self.sess = ort.InferenceSession(onnx_path, sess_options=so, providers=providers)
self.image_keys = image_keys
self.has_state = has_state
self.has_env_state = has_env_state
print(f"[onnx] providers in use: {self.sess.get_providers()}")
def forward(self, batch: dict):
if self.has_state:
state = batch[OBS_STATE]
else:
state = batch[OBS_ENV_STATE]
ref = state
ort_inputs = {"state": state.detach().cpu().numpy().astype(np.float32)}
images = batch[OBS_IMAGES]
for i, img in enumerate(images):
ort_inputs[f"image_{i}"] = img.detach().cpu().numpy().astype(np.float32)
out = self.sess.run(None, ort_inputs)[0]
actions = torch.from_numpy(out).to(ref.device, dtype=ref.dtype)
return actions, None
def load_stats_from_checkpoint(policy_path: str, input_features, output_features) -> dict:
"""Recover MEAN_STD stats baked into a legacy ACT checkpoint's safetensors buffers.
Legacy checkpoints store normalization as buffers like
``normalize_inputs.buffer_observation_state.{mean,std}``. We map those back to
feature names so we can rebuild the processor pipeline without the dataset.
"""
from safetensors.torch import load_file
p = Path(policy_path)
if p.is_dir():
st_path = p / "model.safetensors"
else:
from huggingface_hub import hf_hub_download
st_path = Path(hf_hub_download(policy_path, "model.safetensors"))
sd = load_file(str(st_path))
stats: dict = {}
for feat in list(input_features) + list(output_features):
buf = "buffer_" + feat.replace(".", "_")
for prefix in ("normalize_inputs", "normalize_targets", "unnormalize_outputs"):
mkey, skey = f"{prefix}.{buf}.mean", f"{prefix}.{buf}.std"
if mkey in sd and skey in sd:
stats[feat] = {"mean": sd[mkey].numpy(), "std": sd[skey].numpy()}
break
return stats
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--policy-path", required=True)
parser.add_argument("--task", required=True, help="e.g. AlohaTransferCube-v0")
parser.add_argument("--env-type", default="aloha")
parser.add_argument("--backend", choices=["torch", "onnx"], default="torch")
parser.add_argument("--onnx", default=None, help="Path to .onnx (required for --backend=onnx)")
parser.add_argument("--n-episodes", type=int, default=50)
parser.add_argument("--batch-size", type=int, default=10)
parser.add_argument("--device", default="cuda")
parser.add_argument("--seed", type=int, default=1000)
args = parser.parse_args()
if args.backend == "onnx" and not args.onnx:
raise SystemExit("--backend=onnx requires --onnx=<path>")
device = "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
set_seed(args.seed)
print(f"[1/4] Loading ACT policy from '{args.policy_path}'...")
policy = ACTPolicy.from_pretrained(args.policy_path)
policy.config.device = device
policy.eval()
policy.to(device)
cfg = policy.config
if args.backend == "onnx":
image_keys = list(cfg.image_features)
has_state = cfg.robot_state_feature is not None
has_env_state = cfg.env_state_feature is not None
print(f"[2/4] Swapping policy.model with ONNX backend ({args.onnx})")
policy.model = ONNXACTModel(args.onnx, image_keys, has_state, has_env_state, device)
policy.to(device)
else:
print("[2/4] Using PyTorch backend")
print("[3/4] Building processors and environment...")
stats = load_stats_from_checkpoint(args.policy_path, cfg.input_features, cfg.output_features)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg,
dataset_stats=stats,
preprocessor_overrides={"device_processor": {"device": device}},
)
env_cfg = make_env_config(args.env_type, task=args.task)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=env_cfg, policy_cfg=cfg)
env_groups = make_env(env_cfg, n_envs=args.batch_size, use_async_envs=False)
# make_env returns {task_group: {idx: VectorEnv}}; grab the single env.
first_group = next(iter(env_groups.values()))
env = next(iter(first_group.values()))
print(f"[4/4] Evaluating backend='{args.backend}' for {args.n_episodes} episodes (seed={args.seed})...")
with torch.no_grad():
info = eval_policy(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=args.n_episodes,
start_seed=args.seed,
)
agg = info["aggregated"]
print("\n==== RESULT ====")
print(f"backend : {args.backend}")
print(f"task : {args.task}")
print(f"n_episodes : {args.n_episodes}")
print(f"pc_success : {agg['pc_success']:.1f}%")
print(f"avg_max_reward: {agg['avg_max_reward']:.4f}")
print(f"eval_ep_s : {agg['eval_ep_s']:.2f}s")
env.close()
if __name__ == "__main__":
main()
-133
View File
@@ -1,133 +0,0 @@
#!/usr/bin/env python
"""Export an ACT policy's network to ONNX and verify numerical parity.
Only the inference network is exported (ResNet backbone + transformer enc/dec +
action head). The VAE encoder is training-only and the inference latent is zeros,
so the exported graph is a pure function of (state, images) -> action_chunk.
Normalization stays in the LeRobot processor pipeline (outside ONNX).
Usage:
python examples/onnx/export_act.py \
--policy-path=outputs/converted/act_aloha_sim_transfer_cube_human \
--output=outputs/onnx/act_transfer_cube.onnx
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from torch import nn
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class ACTExportWrapper(nn.Module):
"""Tensor-in/tensor-out wrapper around ACT's inference network."""
def __init__(self, model: nn.Module, image_keys: list[str], has_state: bool, has_env_state: bool):
super().__init__()
self.model = model
self.image_keys = image_keys
self.has_state = has_state
self.has_env_state = has_env_state
def forward(self, state: torch.Tensor, *images: torch.Tensor) -> torch.Tensor:
batch: dict = {}
if self.has_state:
batch[OBS_STATE] = state
if self.has_env_state:
# Convention: when env_state is used it is passed as `state`.
batch[OBS_ENV_STATE] = state
batch[OBS_IMAGES] = list(images)
actions, _ = self.model(batch)
return actions
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--policy-path", required=True, help="Converted ACT checkpoint dir or repo id")
parser.add_argument("--output", required=True, help="Output .onnx path")
parser.add_argument("--opset", type=int, default=17)
parser.add_argument("--atol", type=float, default=1e-3)
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
out = Path(args.output)
out.parent.mkdir(parents=True, exist_ok=True)
print(f"[1/4] Loading ACT policy from '{args.policy_path}'...")
policy = ACTPolicy.from_pretrained(args.policy_path)
policy.eval()
policy.to(args.device)
cfg = policy.config
image_keys = list(cfg.image_features)
has_state = cfg.robot_state_feature is not None
has_env_state = cfg.env_state_feature is not None
state_dim = (cfg.robot_state_feature or cfg.env_state_feature).shape[0]
print(f" image_keys={image_keys} state_dim={state_dim} "
f"chunk_size={cfg.chunk_size} action_dim={cfg.action_feature.shape[0]}")
wrapper = ACTExportWrapper(policy.model, image_keys, has_state, has_env_state).eval().to(args.device)
# Build example inputs (batch size 1) from the config feature shapes.
state_example = torch.randn(1, state_dim, device=args.device)
image_examples = [
torch.rand(1, *cfg.image_features[k].shape, device=args.device) for k in image_keys
]
example_inputs = (state_example, *image_examples)
input_names = ["state"] + [f"image_{i}" for i in range(len(image_keys))]
output_names = ["action_chunk"]
dynamic_axes = {name: {0: "batch"} for name in input_names + output_names}
print(f"[2/4] Exporting to ONNX (opset {args.opset}) -> {out}")
torch.onnx.export(
wrapper,
example_inputs,
str(out),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=args.opset,
do_constant_folding=True,
dynamo=False,
)
print("[3/4] Running parity check (torch vs onnxruntime)...")
import onnxruntime as ort
providers = ["CPUExecutionProvider"]
so = ort.SessionOptions()
so.log_severity_level = 3
sess = ort.InferenceSession(str(out), sess_options=so, providers=providers)
# Fresh random inputs for the check.
state_check = torch.randn(2, state_dim, device=args.device)
image_check = [torch.rand(2, *cfg.image_features[k].shape, device=args.device) for k in image_keys]
with torch.no_grad():
torch_out = wrapper(state_check, *image_check).cpu().numpy()
ort_inputs = {"state": state_check.cpu().numpy()}
for i, img in enumerate(image_check):
ort_inputs[f"image_{i}"] = img.cpu().numpy()
ort_out = sess.run(None, ort_inputs)[0]
max_abs = float(np.max(np.abs(torch_out - ort_out)))
mean_abs = float(np.mean(np.abs(torch_out - ort_out)))
print(f" shapes: torch={torch_out.shape} onnx={ort_out.shape}")
print(f" max_abs_diff={max_abs:.3e} mean_abs_diff={mean_abs:.3e} (atol={args.atol:.0e})")
ok = max_abs <= args.atol
print(f"[4/4] Parity: {'PASS' if ok else 'FAIL'}")
if not ok:
raise SystemExit(f"Parity check failed: max_abs_diff {max_abs:.3e} > atol {args.atol:.0e}")
print(f"\nDone. ONNX model at: {out}")
if __name__ == "__main__":
main()
@@ -54,7 +54,6 @@ from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
@@ -275,11 +274,12 @@ class LanguageColumnsWriter:
new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
)
# Re-emit one row group per episode (a bulk pq.write_table would collapse
# them into one). Write to a sibling tmp path and atomically rename so a
# crash mid-write can't leave a half-written shard.
# Atomic replace: write to a sibling tmp path and rename so a crash
# mid-write can't leave a half-written shard that ``pq.read_table``
# would then fail to open. ``Path.replace`` is atomic on POSIX +
# Windows when source and target sit on the same filesystem.
tmp_path = path.with_suffix(path.suffix + ".tmp")
write_table_one_row_group_per_episode(new_table, tmp_path)
pq.write_table(new_table, tmp_path)
tmp_path.replace(path)
def _materialize_table(
+83 -5
View File
@@ -21,6 +21,7 @@ from torch.optim.lr_scheduler import LRScheduler
from lerobot.configs.train import TrainPipelineConfig
from lerobot.optim import (
load_optimizer_state,
load_optimizer_state_dict,
load_scheduler_state,
save_optimizer_state,
save_scheduler_state,
@@ -98,6 +99,8 @@ def save_checkpoint(
postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
model_state_dict: dict | None = None,
optim_state_dict: dict | None = None,
) -> None:
"""This function creates the following directory structure:
@@ -127,9 +130,18 @@ def save_checkpoint(
resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded).
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
cross-rank collective and passes it here so rank 0 can write it directly. It holds
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
optim_state_dict: Pre-gathered full (unsharded) optimizer state dict. Required under FSDP
(gathered alongside `model_state_dict` via `gather_fsdp_state_dicts`); saved in the same
safetensors format as the single-GPU path. When None, `optimizer.state_dict()` is used.
Defaults to None.
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
cfg.save_pretrained(pretrained_dir)
if cfg.peft is not None:
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
@@ -140,7 +152,13 @@ def save_checkpoint(
if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir)
save_training_state(
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
checkpoint_dir,
step,
optimizer,
scheduler,
num_processes=num_processes,
batch_size=batch_size,
optim_state_dict=optim_state_dict,
)
@@ -151,6 +169,7 @@ def save_training_state(
scheduler: LRScheduler | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
optim_state_dict: dict | None = None,
) -> None:
"""
Saves the training step, optimizer state, scheduler state, and rng state.
@@ -164,19 +183,21 @@ def save_training_state(
Defaults to None.
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
optim_state_dict: Pre-gathered full optimizer state dict (for FSDP). Saved instead of
`optimizer.state_dict()` when provided. Defaults to None.
"""
save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
save_rng_state(save_dir)
if optimizer is not None:
save_optimizer_state(optimizer, save_dir)
save_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
if scheduler is not None:
save_scheduler_state(scheduler, save_dir)
def load_training_state(
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None, load_optimizer: bool = True
) -> tuple[int, Optimizer, LRScheduler | None]:
"""
Loads the training step, optimizer state, scheduler state, and rng state.
@@ -186,6 +207,10 @@ def load_training_state(
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
optimizer (Optimizer): The optimizer to load the state_dict to.
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
load_optimizer (bool, optional): Whether to load the optimizer state from disk. Defaults to
True. Set to False under FSDP, where the sharded optimizer state must be loaded after
`accelerator.prepare()` via `load_fsdp_optimizer_state` (the optimizer is returned
untouched here).
Raises:
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
@@ -200,8 +225,61 @@ def load_training_state(
load_rng_state(training_state_dir)
step = load_training_step(training_state_dir)
optimizer = load_optimizer_state(optimizer, training_state_dir)
if load_optimizer:
optimizer = load_optimizer_state(optimizer, training_state_dir)
if scheduler is not None:
scheduler = load_scheduler_state(scheduler, training_state_dir)
return step, optimizer, scheduler
def gather_fsdp_state_dicts(model, optimizer) -> tuple[dict, dict]:
"""Gather the full (unsharded) model and optimizer state dicts under FSDP.
`model.state_dict()` and `FSDP.optim_state_dict(...)` are cross-rank collectives, so this must be
called on *every* rank with the prepared (FSDP-wrapped) `model` and `optimizer`. With
`rank0_only=True` and `offload_to_cpu=True`, every rank runs the all-gather but only rank 0
materializes the full dicts (the others get empty dicts) and they are kept on CPU to bound GPU
memory. The returned optimizer state dict is keyed by parameter FQNs and is world-size
independent; `load_fsdp_optimizer_state` reshards it on resume.
Returns:
(model_state_dict, optim_state_dict): full dicts on rank 0, empty dicts on other ranks.
"""
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP, # noqa F401
StateDictType,
)
state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
model_state_dict = model.state_dict()
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
return model_state_dict, optim_state_dict
def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None:
"""Load the FSDP optimizer state (saved as safetensors) and reshard it into the optimizer.
This is a cross-rank collective and must be called on every rank *after* `accelerator.prepare()`
with the prepared (FSDP-wrapped) `model` and `optimizer`. The saved state is the full,
world-size-independent optimizer state (keyed by parameter FQNs); `FSDP.optim_state_dict_to_load`
reshards it to the current FSDP topology, so resume on a different number of GPUs works.
"""
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP, # noqa F401
StateDictType,
)
# Every rank reads the same full state from the (shared) checkpoint dir, so rank0_only=False.
full_osd = load_optimizer_state_dict(checkpoint_dir / TRAINING_STATE_DIR)
state_cfg = FullStateDictConfig(rank0_only=False)
optim_cfg = FullOptimStateDictConfig(rank0_only=False)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
optimizer.load_state_dict(sharded_osd)
-9
View File
@@ -32,7 +32,6 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
from .io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
to_parquet_one_row_group_per_episode,
to_parquet_with_hf_images,
write_info,
write_stats,
@@ -552,7 +551,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
aggr_root=dst_meta.root,
hf_features=hf_features,
concatenate=concatenate_data,
one_row_group_per_episode=True,
)
# Record the mapping from source to actual destination
@@ -630,7 +628,6 @@ def append_or_create_parquet_file(
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
concatenate: bool = True,
one_row_group_per_episode: bool = False,
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -648,8 +645,6 @@ def append_or_create_parquet_file(
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one.
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
the episodes-metadata parquet (already one row per episode).
Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -662,8 +657,6 @@ def append_or_create_parquet_file(
dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(df, dst_path)
else:
df.to_parquet(dst_path)
return idx, (dst_chunk, dst_file)
@@ -690,8 +683,6 @@ def append_or_create_parquet_file(
if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(final_df, target_path)
else:
final_df.to_parquet(target_path)
+9 -38
View File
@@ -20,7 +20,6 @@ import datasets
import numpy as np
import pandas
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
@@ -271,49 +270,21 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
return items_dict
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
"""Write ``table`` with one parquet row group per episode (in episode order).
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
mirroring the recording writer. ``table`` must carry a contiguous
``episode_index`` column.
"""
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
try:
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
finally:
writer.close()
def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None:
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
Images are embedded into the arrow table first (``ParquetWriter.write_table``
does not embed external image files like ``Dataset.to_parquet`` does).
``features`` types image columns as ``Image()`` in the parquet schema.
Args:
df: DataFrame to write to parquet.
path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
"""
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds = embed_images(ds)
table = ds.with_format("arrow")[:]
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
# No episode boundaries to align row groups to — keep a single write.
pq.write_table(table, str(path))
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
table = pa.Table.from_pandas(df, preserve_index=False)
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
pq.write_table(table, str(path))
ds.to_parquet(path)
def item_to_torch(item: dict) -> dict:
+3 -5
View File
@@ -70,21 +70,19 @@ def aggregate_pipeline_dataset_features(
initial_features: dict[PipelineFeatureType, dict[str, Any]],
*,
use_videos: bool = True,
exclude_images: bool = False,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `exclude_images` and `patterns`, and finally
(image or state), filters them based on `use_videos` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset.
Args:
pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
exclude_images: If True, image features are dropped entirely from the output.
use_videos: If False, image features are excluded.
patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter.
@@ -122,7 +120,7 @@ def aggregate_pipeline_dataset_features(
)
# 2. Apply filtering rules.
if is_image and exclude_images:
if is_image and not use_videos:
continue
if not is_image and not should_keep(key, compiled_patterns):
continue
+2
View File
@@ -20,6 +20,7 @@ from .optimizers import (
SGDConfig as SGDConfig,
XVLAAdamWConfig as XVLAAdamWConfig,
load_optimizer_state,
load_optimizer_state_dict,
save_optimizer_state,
)
from .schedulers import (
@@ -50,6 +51,7 @@ __all__ = [
"VQBeTSchedulerConfig",
# State management
"load_optimizer_state",
"load_optimizer_state_dict",
"load_scheduler_state",
"save_optimizer_state",
"save_scheduler_state",
+30 -5
View File
@@ -27,7 +27,7 @@ from lerobot.utils.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
from lerobot.utils.io_utils import deserialize_json_into_object, write_json
from lerobot.utils.io_utils import deserialize_json_into_object, load_json, write_json
from lerobot.utils.utils import flatten_dict, unflatten_dict
# Type alias for parameters accepted by optimizer build() methods.
@@ -281,28 +281,37 @@ class MultiAdamConfig(OptimizerConfig):
def save_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer],
save_dir: Path,
optim_state_dict: dict | None = None,
) -> None:
"""Save optimizer state to disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to save the optimizer state.
optim_state_dict: Pre-gathered optimizer state dict (for FSDP, where the sharded state must
be gathered across ranks first). If provided, it is saved directly instead of calling
``optimizer.state_dict()``. Only supported for a single optimizer. Defaults to None.
"""
if isinstance(optimizer, dict):
# Handle dictionary of optimizers
if optim_state_dict is not None:
raise ValueError("optim_state_dict is not supported for a dict of optimizers")
for name, opt in optimizer.items():
optimizer_dir = save_dir / name
optimizer_dir.mkdir(exist_ok=True, parents=True)
_save_single_optimizer_state(opt, optimizer_dir)
else:
# Handle single optimizer
_save_single_optimizer_state(optimizer, save_dir)
_save_single_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
def _save_single_optimizer_state(
optimizer: torch.optim.Optimizer, save_dir: Path, optim_state_dict: dict | None = None
) -> None:
"""Save a single optimizer's state to disk."""
state = optimizer.state_dict()
state = dict(optim_state_dict) if optim_state_dict is not None else optimizer.state_dict()
param_groups = state.pop("param_groups")
flat_state = flatten_dict(state)
save_file(flat_state, save_dir / OPTIMIZER_STATE)
@@ -356,3 +365,19 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
optimizer.load_state_dict(loaded_state_dict)
return optimizer
def load_optimizer_state_dict(save_dir: Path) -> dict:
"""Read a saved optimizer state dict (safetensors + json) back into a plain dict.
Unlike `load_optimizer_state`, this does not load into an optimizer and preserves the original
``state`` keys verbatim (e.g. FSDP parameter FQNs, which are not integer-castable). It is used by
the FSDP resume path, where the full state must be resharded via `FSDP.optim_state_dict_to_load`
before being loaded into the (sharded) optimizer.
"""
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
return {
"state": state.get("state", {}),
"param_groups": load_json(save_dir / OPTIMIZER_PARAM_GROUPS),
}
+39 -4
View File
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
import packaging
import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
@@ -129,10 +129,43 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'")
def _save_pretrained(self, save_directory: Path) -> None:
def save_pretrained(
self,
save_directory: str | Path,
*,
state_dict: dict[str, Tensor] | None = None,
repo_id: str | None = None,
push_to_hub: bool = False,
card_kwargs: dict | None = None,
**push_to_hub_kwargs,
) -> str | None:
"""Save the policy to a directory (and optionally push to the Hub).
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
return sharded tensors, so the caller gathers the full state dict via a cross-rank
collective and passes it here for `_save_pretrained` to write directly.
"""
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
self._save_pretrained(save_directory, state_dict=state_dict)
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return None
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
if state_dict is None:
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
return
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
@classmethod
def from_pretrained(
@@ -270,6 +303,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self,
cfg: TrainPipelineConfig,
peft_model=None,
state_dict: dict[str, Tensor] | None = None,
):
api = HfApi()
repo_id = api.create_repo(
@@ -287,7 +321,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
peft_model.save_pretrained(saved_path)
self.config.save_pretrained(saved_path)
else:
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
# Calls _save_pretrained and stores model tensors
self.save_pretrained(saved_path, state_dict=state_dict)
card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
logger = logging.getLogger(__name__)
class BiOpenArmFollower(BimanualMixin, Robot):
class BiOpenArmFollower(Robot):
"""
Bimanual OpenArm Follower Arms
"""
@@ -40,17 +39,15 @@ class BiOpenArmFollower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
# will only open the cameras assigned to it. Per-arm cameras are used
# as fallback when top-level cameras are empty.
if config.cameras:
left_cameras = config.cameras
right_cameras = {}
else:
left_cameras = config.left_arm_config.cameras
right_cameras = config.right_arm_config.cameras
left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None,
@@ -59,7 +56,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras,
cameras=left_cameras,
side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd,
@@ -78,7 +75,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=config.right_arm_config.cameras,
cameras=right_cameras,
side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd,
@@ -98,19 +95,22 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@property
def _motors_ft(self) -> dict[str, type]:
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return {
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
}
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
# "right_wrist"), so we merge them directly — unlike motors which need the
# left_/right_ prefix to disambiguate identical per-arm joint names.
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -120,6 +120,27 @@ class BiOpenArmFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -127,15 +148,21 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Camera keys that should NOT get the arm prefix (they already have unique names)
left_cam_keys = set(self.left_arm.cameras.keys())
right_cam_keys = set(self.right_arm.cameras.keys())
# Add "right_" prefix
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
right_obs = self.right_arm.get_observation()
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
return obs_dict
@@ -162,4 +189,9 @@ class BiOpenArmFollower(BimanualMixin, Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -32,7 +32,5 @@ class BiOpenArmFollowerConfig(RobotConfig):
left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
# Top-level cameras shared across both arms.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
logger = logging.getLogger(__name__)
class BiRebotB601Follower(BimanualMixin, Robot):
class BiRebotB601Follower(Robot):
"""Bimanual Seeed Studio reBot B601-DM follower.
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
@@ -42,18 +41,6 @@ class BiRebotB601Follower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -62,7 +49,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
dm_serial_baud=config.left_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras,
cameras=config.left_arm_config.cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
@@ -99,12 +86,10 @@ class BiRebotB601Follower(BimanualMixin, Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
return {
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -114,13 +99,32 @@ class BiRebotB601Follower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
for k, v in self.left_arm.get_observation().items():
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm.get_observation().items():
obs_dict[f"right_{k}"] = v
obs_dict = {}
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
return obs_dict
@check_if_not_connected
@@ -139,3 +143,8 @@ class BiRebotB601Follower(BimanualMixin, Robot):
**{f"left_{k}": v for k, v in sent_action_left.items()},
**{f"right_{k}": v for k, v in sent_action_right.items()},
}
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from dataclasses import dataclass
from ..config import RobotConfig
from ..rebot_b601_follower import RebotB601FollowerConfig
@@ -29,8 +27,3 @@ class BiRebotB601FollowerConfig(RobotConfig):
left_arm_config: RebotB601FollowerConfig
right_arm_config: RebotB601FollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from ..so_follower import SOFollower, SOFollowerRobotConfig
@@ -28,7 +27,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
logger = logging.getLogger(__name__)
class BiSOFollower(BimanualMixin, Robot):
class BiSOFollower(Robot):
"""
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
@@ -40,18 +39,6 @@ class BiSOFollower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = SOFollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -59,7 +46,7 @@ class BiSOFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
use_degrees=config.left_arm_config.use_degrees,
cameras=left_arm_cameras,
cameras=config.left_arm_config.cameras,
)
right_arm_config = SOFollowerRobotConfig(
@@ -90,12 +77,13 @@ class BiSOFollower(BimanualMixin, Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
left_arm_cameras_ft = self.left_arm._cameras_ft
right_arm_cameras_ft = self.right_arm._cameras_ft
return {
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -105,21 +93,42 @@ class BiSOFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Add "left_" prefix
left_obs = self.left_arm.get_observation()
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
# Add "right_" prefix
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
right_obs = self.right_arm.get_observation()
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
return obs_dict
@@ -142,3 +151,8 @@ class BiSOFollower(BimanualMixin, Robot):
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from dataclasses import dataclass
from ..config import RobotConfig
from ..so_follower import SOFollowerConfig
@@ -29,8 +27,3 @@ class BiSOFollowerConfig(RobotConfig):
left_arm_config: SOFollowerConfig
right_arm_config: SOFollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
-1
View File
@@ -54,7 +54,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
@@ -57,7 +57,6 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
-1
View File
@@ -137,7 +137,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
-1
View File
@@ -174,7 +174,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
@@ -41,7 +41,6 @@ from lerobot.robots import ( # noqa: F401
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
koch_leader,
@@ -89,7 +89,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
+28 -4
View File
@@ -34,8 +34,10 @@ from torch.optim import Optimizer
from tqdm import tqdm
from lerobot.common.train_utils import (
gather_fsdp_state_dicts,
get_step_checkpoint_dir,
get_step_identifier,
load_fsdp_optimizer_state,
load_training_batch_size,
load_training_num_processes,
load_training_state,
@@ -189,6 +191,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
require_package("accelerate", extra="training")
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
cfg.validate()
@@ -197,8 +200,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation
if accelerator is None:
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
@@ -370,7 +371,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
# Under FSDP the optimizer state is sharded and must be loaded after `accelerator.prepare()`
# (see load_fsdp_optimizer_state below), so skip the optimizer here and load it then.
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
step, optimizer, lr_scheduler = load_training_state(
cfg.checkpoint_path, optimizer, lr_scheduler, load_optimizer=not is_fsdp
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
@@ -460,6 +466,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
# FSDP optimizer state is sharded across ranks, so it can only be loaded once the optimizer and
# model are FSDP-wrapped (i.e. after `prepare`). Collective: every rank must participate.
if cfg.resume and accelerator.distributed_type == DistributedType.FSDP:
load_fsdp_optimizer_state(policy, optimizer, cfg.checkpoint_path)
dl_iter = cycle(dataloader)
policy.train()
@@ -558,6 +570,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
# Under FSDP, gathering the full model + optimizer state dicts is a cross-rank collective,
# so all ranks must participate; rank 0 then writes the materialized dicts. For DDP /
# single-GPU the state dicts are saved the normal way inside save_checkpoint.
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
if is_fsdp:
model_state_dict, optim_state_dict = gather_fsdp_state_dicts(policy, optimizer)
else:
model_state_dict, optim_state_dict = None, None
if is_main_process:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
@@ -572,6 +592,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
postprocessor=postprocessor,
num_processes=accelerator.num_processes,
batch_size=cfg.batch_size,
model_state_dict=model_state_dict,
optim_state_dict=optim_state_dict,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
@@ -634,6 +656,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if eval_env:
close_envs(eval_env)
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
model_state_dict = accelerator.get_state_dict(policy) if is_fsdp else None
if is_main_process:
logging.info("End of training")
@@ -643,7 +667,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if not cfg.is_reward_model_training and cfg.policy.use_peft:
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
else:
unwrapped_model.push_model_to_hub(cfg)
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
preprocessor.push_to_hub(active_cfg.repo_id)
postprocessor.push_to_hub(active_cfg.repo_id)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
from ..teleoperator import Teleoperator
@@ -28,7 +27,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
logger = logging.getLogger(__name__)
class BiOpenArmLeader(BimanualMixin, Teleoperator):
class BiOpenArmLeader(Teleoperator):
"""
Bimanual OpenArm Leader Arms
"""
@@ -87,6 +86,27 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -109,3 +129,8 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
@dataclass
class BiOpenArmLeaderConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Leader teleoperators."""
"""Configuration class for Bi OpenArm Follower robots."""
left_arm_config: OpenArmLeaderConfigBase
right_arm_config: OpenArmLeaderConfigBase
@@ -1,20 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_openarm_mini import BiOpenArmMini
from .config_bi_openarm_mini import BiOpenArmMiniConfig
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
@@ -1,101 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
from ..teleoperator import Teleoperator
from .config_bi_openarm_mini import BiOpenArmMiniConfig
logger = logging.getLogger(__name__)
class BiOpenArmMini(BimanualMixin, Teleoperator):
"""Bimanual OpenArm Mini teleoperator.
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
bimanual leader can teleoperate a bimanual OpenArm follower.
"""
config_class = BiOpenArmMiniConfig
name = "bi_openarm_mini"
def __init__(self, config: BiOpenArmMiniConfig):
super().__init__(config)
self.config = config
# `side` is forced to match left/right regardless of what the user passed
# on the per-arm base config — the bimanual wrapper owns the side semantics.
left_arm_config = OpenArmMiniConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
side="left",
use_degrees=config.left_arm_config.use_degrees,
)
right_arm_config = OpenArmMiniConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
side="right",
use_degrees=config.right_arm_config.use_degrees,
)
self.left_arm = OpenArmMini(left_arm_config)
self.right_arm = OpenArmMini(right_arm_config)
@cached_property
def action_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
}
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
action: RobotAction = {}
for k, v in self.left_arm.get_action().items():
action[f"left_{k}"] = v
for k, v in self.right_arm.get_action().items():
action[f"right_{k}"] = v
return action
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
if left_fb:
self.left_arm.send_feedback(left_fb)
if right_fb:
self.right_arm.send_feedback(right_fb)
@@ -1,29 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
from ..openarm_mini import OpenArmMiniConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
@dataclass
class BiOpenArmMiniConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Mini teleoperators."""
left_arm_config: OpenArmMiniConfigBase
right_arm_config: OpenArmMiniConfigBase
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_rebot_102_leader import BiRebot102Leader
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
from .bi_rebot_102_leader import BiRebotArm102Leader
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
__all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]
@@ -18,17 +18,16 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
from ..teleoperator import Teleoperator
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
logger = logging.getLogger(__name__)
class BiRebot102Leader(BimanualMixin, Teleoperator):
class BiRebotArm102Leader(Teleoperator):
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
@@ -36,10 +35,10 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
leader can teleoperate a bimanual reBot B601 follower.
"""
config_class = BiRebot102LeaderConfig
config_class = BiRebotArm102LeaderConfig
name = "bi_rebot_102_leader"
def __init__(self, config: BiRebot102LeaderConfig):
def __init__(self, config: BiRebotArm102LeaderConfig):
super().__init__(config)
self.config = config
@@ -77,6 +76,27 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_action(self) -> RobotAction:
action_dict = {}
@@ -86,3 +106,8 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
@dataclass
class BiRebot102LeaderConfig(TeleoperatorConfig):
class BiRebotArm102LeaderConfig(TeleoperatorConfig):
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
left_arm_config: RebotArm102LeaderConfig
@@ -17,9 +17,7 @@
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..so_leader import SOLeader, SOLeaderTeleopConfig
from ..teleoperator import Teleoperator
@@ -28,7 +26,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
logger = logging.getLogger(__name__)
class BiSOLeader(BimanualMixin, Teleoperator):
class BiSOLeader(Teleoperator):
"""
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
@@ -69,12 +67,33 @@ class BiSOLeader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
def get_action(self) -> dict[str, float]:
action_dict = {}
# Add "left_" prefix
@@ -90,3 +109,8 @@ class BiSOLeader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
from .config_openarm_mini import OpenArmMiniConfig
from .openarm_mini import OpenArmMini
__all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
@@ -19,21 +19,12 @@ from dataclasses import dataclass
from ..config import TeleoperatorConfig
@dataclass
class OpenArmMiniConfigBase:
"""Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
# Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
port: str
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
# during readout. If `None`, no flipping is applied.
side: str | None = None
use_degrees: bool = True
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
pass
class OpenArmMiniConfig(TeleoperatorConfig):
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
port_right: str = "/dev/ttyUSB0"
port_left: str = "/dev/ttyUSB1"
use_degrees: bool = True
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__)
# Per-side motor direction flips applied during readout.
SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
"left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
}
# Motors whose direction is inverted during readout
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
# Leader joint 6 follower joint 7 (symmetric — its own inverse).
# Leader joint 6 maps to follower joint 7 and vice versa
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
GRIPPER_TELEOP_TO_DEGREES = -0.65
class OpenArmMini(Teleoperator):
"""OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
"""
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
"""
config_class = OpenArmMiniConfig
@@ -56,12 +56,9 @@ class OpenArmMini(Teleoperator):
super().__init__(config)
self.config = config
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
norm_mode_body = MotorNormMode.DEGREES
motors = {
motors_right = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
@@ -72,15 +69,46 @@ class OpenArmMini(Teleoperator):
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
self.bus = FeetechMotorsBus(
port=self.config.port,
motors=motors,
calibration=self.calibration,
motors_left = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
)
@property
def action_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property
def feedback_features(self) -> dict[str, type]:
@@ -88,12 +116,14 @@ class OpenArmMini(Teleoperator):
@property
def is_connected(self) -> bool:
return self.bus.is_connected
return self.bus_right.is_connected and self.bus_left.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting arm on {self.config.port}...")
self.bus.connect()
logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
if calibrate:
self.calibrate()
@@ -103,14 +133,14 @@ class OpenArmMini(Teleoperator):
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for a single OpenArm Mini arm.
Run calibration procedure for OpenArm Mini.
1. Disable torque
2. Ask user to position arm in hanging position with gripper closed
2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions)
5. Save calibration
@@ -122,51 +152,70 @@ class OpenArmMini(Teleoperator):
)
if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}")
self.bus.write_calibration(self.calibration)
cal_right = {
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return
logger.info(f"\nRunning calibration for {self}")
self.bus.disable_torque()
self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
logger.info("Setting Phase to 12 for all motors...")
for motor in self.bus.motors:
self.bus.write("Phase", motor, 12)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
"""Calibrate a single arm with Feetech motors."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input(
"\nCalibration: Zero Position\n"
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
homing_offsets = self.bus.set_half_turn_homings()
logger.info("Arm zero position set.")
homing_offsets = bus.set_half_turn_homings()
logger.info(f"{arm_name.capitalize()} arm zero position set.")
print("\nSetting motor ranges\n")
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
if self.calibration is None:
self.calibration = {}
motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
max_res = motor_resolution - 1
for motor_name, motor in self.bus.motors.items():
for motor_name, motor in bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
if motor_name == "gripper":
input(
"\nGripper Calibration\n"
"Step 1: CLOSE the gripper fully\n"
"Press ENTER when gripper is closed..."
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
f"Step 1: CLOSE the gripper fully\n"
f"Press ENTER when gripper is closed..."
)
closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
open_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos:
@@ -179,16 +228,16 @@ class OpenArmMini(Teleoperator):
drive_mode = 1
logger.info(
f" {motor_name}: range set to [{range_min}, {range_max}] "
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})"
)
else:
range_min = 0
range_max = max_res
drive_mode = 0
logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[motor_name] = MotorCalibration(
self.calibration[prefixed_name] = MotorCalibration(
id=motor.id,
drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name],
@@ -196,68 +245,108 @@ class OpenArmMini(Teleoperator):
range_max=range_max,
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
cal_for_bus = {
k.replace(f"{arm_name}_", ""): v
for k, v in self.calibration.items()
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
def configure(self) -> None:
self.bus.disable_torque()
self.bus.configure_motors()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_right.disable_torque()
self.bus_right.configure_motors()
for motor in self.bus_right.motors:
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
print("\nSetting up RIGHT arm motors...")
for motor in reversed(self.bus_right.motors):
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
self.bus_right.setup_motor(motor)
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> RobotAction:
"""Get current action (read positions from all motors)."""
"""Get current action from both arms (read positions from all motors)."""
start = time.perf_counter()
positions = self.bus.sync_read("Present_Position")
right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
# Per-side direction flip is applied based on the configured `side`.
action: dict[str, Any] = {}
for motor, val in positions.items():
for motor, val in right_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def enable_torque(self) -> None:
self.bus.enable_torque()
"""Enable torque on both arms for position control."""
self.bus_right.enable_torque()
self.bus_left.enable_torque()
def disable_torque(self) -> None:
self.bus.disable_torque()
"""Disable torque on both arms for free movement."""
self.bus_right.disable_torque()
self.bus_left.disable_torque()
def write_goal_positions(self, positions: dict[str, float]) -> None:
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
goals: dict[str, float] = {}
right_goals: dict[str, float] = {}
left_goals: dict[str, float] = {}
for key, val in positions.items():
if not key.endswith(".pos"):
continue
base = key.removesuffix(".pos")
# JOINT_REMAP is symmetric (its own inverse).
target = JOINT_REMAP.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
goals[target] = -val if target in self._motors_to_flip else val
motor_name = key.removesuffix(".pos")
if motor_name.startswith("right_"):
base = motor_name.removeprefix("right_")
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
elif motor_name.startswith("left_"):
base = motor_name.removeprefix("left_")
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
if goals:
self.bus.sync_write("Goal_Position", goals)
if right_goals:
self.bus_right.sync_write("Goal_Position", right_goals)
if left_goals:
self.bus_left.sync_write("Goal_Position", left_goals)
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
@@ -265,5 +354,6 @@ class OpenArmMini(Teleoperator):
@check_if_not_connected
def disconnect(self) -> None:
self.bus.disconnect()
self.bus_right.disconnect()
self.bus_left.disconnect()
logger.info(f"{self} disconnected.")
+2 -6
View File
@@ -99,18 +99,14 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .openarm_mini import OpenArmMini
return OpenArmMini(config)
elif config.type == "bi_openarm_mini":
from .bi_openarm_mini import BiOpenArmMini
return BiOpenArmMini(config)
elif config.type == "rebot_102_leader":
from .rebot_102_leader import RebotArm102Leader
return RebotArm102Leader(config)
elif config.type == "bi_rebot_102_leader":
from .bi_rebot_102_leader import BiRebot102Leader
from .bi_rebot_102_leader import BiRebotArm102Leader
return BiRebot102Leader(config)
return BiRebotArm102Leader(config)
else:
try:
return cast("Teleoperator", make_device_from_device_class(config))
-63
View File
@@ -1,63 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
class BimanualMixin:
"""Lifecycle delegation for bimanual robots and teleoperators.
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
their own ``__init__``. They retain ownership of feature dicts and the
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
``send_feedback``), which vary per-embodiment.
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
take precedence in the MRO::
class BiFooFollower(BimanualMixin, Robot): ...
"""
left_arm: Any
right_arm: Any
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
-73
View File
@@ -28,7 +28,6 @@ import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import pyarrow.parquet as pq # noqa: E402
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
@@ -345,78 +344,6 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
assert len(dataset) == 24
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
from lerobot.datasets.io_utils import write_tasks
from lerobot.utils.io_utils import write_json
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
for ep, length in enumerate(episode_lengths):
episode_index += [ep] * length
frame_index += list(range(length))
timestamp += [round(i / fps, 6) for i in range(length)]
task_index += [0] * length
subtask_index += [0] * length # legacy column the writer must drop
pd.DataFrame(
{
"episode_index": episode_index,
"frame_index": frame_index,
"timestamp": timestamp,
"task_index": task_index,
"subtask_index": subtask_index,
}
).to_parquet(data_dir / "file-000.parquet", index=False)
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
write_tasks(tasks_df, root)
write_json(
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
root / "meta" / "info.json",
)
return root
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
"""Rewriting a packed shard must keep one row group per episode, not collapse
every episode into a single giant row group."""
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
shard = root / "data" / "chunk-000" / "file-000.parquet"
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
staging_dir = tmp_path / "stage"
for ep in range(len(episode_lengths)):
_stage_episode(
staging_dir,
ep,
plan=[
{
"role": "assistant",
"content": f"subtask for ep {ep}",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
}
],
)
records = list(iter_episodes(root))
LanguageColumnsWriter().write_all(records, staging_dir, root)
# One row group per episode, with row counts matching the episode lengths.
md = pq.ParquetFile(shard).metadata
assert md.num_row_groups == len(episode_lengths)
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
# Language columns are still present after the per-episode rewrite.
table = pq.read_table(shard)
assert "language_persistent" in table.column_names
assert "language_events" in table.column_names
def test_speech_atom_shape_matches_plan_spec() -> None:
atom = speech_atom(2.5, "I'm cleaning up!")
assert atom["role"] == "assistant"
-55
View File
@@ -32,26 +32,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID
def assert_data_shards_one_row_group_per_episode(root):
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
import pyarrow.parquet as pq
shards = sorted((root / "data").rglob("*.parquet"))
assert shards, f"no data shards found under {root}/data"
n_episodes = 0
for shard in shards:
pf = pq.ParquetFile(shard)
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
assert pf.metadata.num_row_groups == len(set(episodes)), shard
for i in range(pf.metadata.num_row_groups):
rg_episodes = set(
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
)
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
n_episodes += len(set(episodes))
return n_episodes
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
"""Test that total number of episodes and frames are correctly aggregated."""
assert aggr_ds.num_episodes == expected_episodes, (
@@ -586,41 +566,6 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
)
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
Covers both the non-image (``df.to_parquet``) and image
(``to_parquet_with_hf_images``) write branches, including the merge-into-
existing-file branch via a low file-size threshold that forces packing.
"""
ds_0 = lerobot_dataset_factory(
root=tmp_path / "rg_0",
repo_id=f"{DUMMY_REPO_ID}_rg_0",
total_episodes=3,
total_frames=60,
use_videos=use_videos,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "rg_1",
repo_id=f"{DUMMY_REPO_ID}_rg_1",
total_episodes=4,
total_frames=80,
use_videos=use_videos,
)
aggr_root = tmp_path / "rg_aggr"
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
aggr_root=aggr_root,
)
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
+39
View File
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import (
MultiAdamConfig,
SGDConfig,
load_optimizer_state,
load_optimizer_state_dict,
save_optimizer_state,
)
from lerobot.utils.constants import (
@@ -65,6 +66,44 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
def test_save_and_load_fsdp_optimizer_state_dict_roundtrip(tmp_path):
"""The FSDP full optimizer state dict is keyed by parameter FQNs (dotted strings), not the
integer indices of the single-GPU path. Verify it survives the safetensors save -> read
round-trip used by the FSDP save/resume path (save_optimizer_state(optim_state_dict=...) then
load_optimizer_state_dict), which the flatten/unflatten "/" separator must not corrupt."""
full_osd = {
"state": {
"model.layers.0.weight": {
"step": torch.tensor(3.0),
"exp_avg": torch.randn(4, 4),
"exp_avg_sq": torch.randn(4, 4),
},
"model.layers.0.bias": {
"step": torch.tensor(3.0),
"exp_avg": torch.randn(4),
"exp_avg_sq": torch.randn(4),
},
},
"param_groups": [
{"lr": 1e-4, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.0, "params": [0, 1]}
],
}
save_optimizer_state(
torch.optim.Adam([torch.nn.Parameter(torch.randn(1))]), tmp_path, optim_state_dict=full_osd
)
assert (tmp_path / OPTIMIZER_STATE).is_file()
assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file()
loaded = load_optimizer_state_dict(tmp_path)
# FQN keys must be preserved verbatim (not int-cast, not split on their dots).
assert set(loaded["state"].keys()) == set(full_osd["state"].keys())
for fqn, sub in full_osd["state"].items():
for k, v in sub.items():
torch.testing.assert_close(loaded["state"][fqn][k], v)
assert loaded["param_groups"] == full_osd["param_groups"]
@pytest.fixture
def base_params_dict():
return {
+24
View File
@@ -23,6 +23,7 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from packaging import version
from safetensors.torch import load_file
@@ -300,6 +301,29 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
"""Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
policy_cls = get_policy_class("act")
policy_cfg = make_policy_config("act")
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
policy = policy_cls(policy_cfg)
policy.to(policy_cfg.device)
save_dir = tmp_path / "fsdp_state_dict"
policy.save_pretrained(save_dir, state_dict=policy.state_dict())
# A single, unsharded safetensors file (no sharded set + index).
assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
@pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool):
"""
+3 -21
View File
@@ -2370,32 +2370,14 @@ def test_aggregate_images_when_use_videos_false():
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=False, # images kept, stored as "image" dtype
use_videos=False, # expect "image" dtype
patterns=None,
)
key = f"{OBS_IMAGES}.back"
key_front = f"{OBS_IMAGES}.front"
assert key in out
assert key_front in out
assert out[key]["dtype"] == "image"
assert out[key_front]["dtype"] == "image"
assert out[key]["shape"] == initial["back"]
def test_aggregate_images_excluded():
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
initial = {"back": (480, 640, 3)}
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
exclude_images=True,
patterns=None,
)
assert f"{OBS_IMAGES}.back" not in out
assert f"{OBS_IMAGES}.front" not in out
assert key not in out
assert key_front not in out
def test_aggregate_images_when_use_videos_true():
+3 -3
View File
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
import pytest
from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
from lerobot.teleoperators.rebot_102_leader import (
RebotArm102Leader,
RebotArm102LeaderConfig,
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebot102LeaderConfig(
cfg = BiRebotArm102LeaderConfig(
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
)
teleop = BiRebot102Leader(cfg)
teleop = BiRebotArm102Leader(cfg)
assert any(k.startswith("left_") for k in teleop.action_features)
assert any(k.startswith("right_") for k in teleop.action_features)
assert "left_gripper.pos" in teleop.action_features
+108 -13
View File
@@ -58,7 +58,46 @@ def download_dataset(repo_id, episodes):
print(f"Dataset {repo_id} downloaded successfully")
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
def _write_multi_gpu_config(f, num_processes):
f.write("compute_environment: LOCAL_MACHINE\n")
f.write("distributed_type: MULTI_GPU\n")
f.write("mixed_precision: 'no'\n")
f.write(f"num_processes: {num_processes}\n")
f.write("use_cpu: false\n")
f.write("gpu_ids: all\n")
f.write("downcast_bf16: 'no'\n")
f.write("machine_rank: 0\n")
f.write("main_training_function: main\n")
f.write("num_machines: 1\n")
f.write("rdzv_backend: static\n")
f.write("same_network: true\n")
def _write_fsdp_config(f, num_processes):
# FSDP1 with FULL_SHARD (ZeRO-3-equivalent) and FULL_STATE_DICT, matching
# docs/source/multi_gpu_training.mdx. ACT's repeated transformer blocks are the wrap units;
# fsdp_use_orig_params is required because LeRobot builds the optimizer before prepare().
f.write("compute_environment: LOCAL_MACHINE\n")
f.write("distributed_type: FSDP\n")
f.write("mixed_precision: 'no'\n")
f.write(f"num_processes: {num_processes}\n")
f.write("use_cpu: false\n")
f.write("gpu_ids: all\n")
f.write("machine_rank: 0\n")
f.write("main_training_function: main\n")
f.write("num_machines: 1\n")
f.write("rdzv_backend: static\n")
f.write("same_network: true\n")
f.write("fsdp_config:\n")
f.write(" fsdp_version: 1\n")
f.write(" fsdp_sharding_strategy: FULL_SHARD\n")
f.write(" fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n")
f.write(" fsdp_transformer_layer_cls_to_wrap: ACTEncoderLayer,ACTDecoderLayer\n")
f.write(" fsdp_use_orig_params: true\n")
f.write(" fsdp_state_dict_type: FULL_STATE_DICT\n")
def run_accelerate_training(config_args, num_processes=4, temp_dir=None, distributed_type="MULTI_GPU"):
"""
Helper function to run training with accelerate launch.
@@ -66,6 +105,7 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
config_args: List of config arguments to pass to lerobot_train.py
num_processes: Number of processes (GPUs) to use
temp_dir: Temporary directory for outputs
distributed_type: "MULTI_GPU" (DDP) or "FSDP" selects the generated accelerate config.
Returns:
subprocess.CompletedProcess result
@@ -75,18 +115,10 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
# Write YAML config
with open(config_path, "w") as f:
f.write("compute_environment: LOCAL_MACHINE\n")
f.write("distributed_type: MULTI_GPU\n")
f.write("mixed_precision: 'no'\n")
f.write(f"num_processes: {num_processes}\n")
f.write("use_cpu: false\n")
f.write("gpu_ids: all\n")
f.write("downcast_bf16: 'no'\n")
f.write("machine_rank: 0\n")
f.write("main_training_function: main\n")
f.write("num_machines: 1\n")
f.write("rdzv_backend: static\n")
f.write("same_network: true\n")
if distributed_type == "FSDP":
_write_fsdp_config(f, num_processes)
else:
_write_multi_gpu_config(f, num_processes)
cmd = [
"accelerate",
@@ -211,3 +243,66 @@ class TestMultiGPUTraining:
# Verify optimizer state exists
optimizer_state = training_state_dir / "optimizer_state.safetensors"
assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
def test_fsdp_optimizer_save_and_resume(self):
"""
Test that FSDP saves the (gathered) optimizer state and can resume from it.
Trains a few steps under FSDP, verifies the gathered optimizer state is written next to the
rest of the training state, then resumes from the checkpoint for more steps and checks it
completes without shape/key errors in the FSDP optimizer load path.
"""
# Pre-download dataset to avoid race conditions
download_dataset("lerobot/pusht", episodes=[0])
with tempfile.TemporaryDirectory() as temp_dir:
output_dir = Path(temp_dir) / "outputs"
config_args = [
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0]",
"--policy.type=act",
"--policy.device=cuda",
"--policy.push_to_hub=false",
f"--output_dir={output_dir}",
"--batch_size=4",
"--steps=10",
"--eval_freq=-1",
"--log_freq=5",
"--save_freq=10",
"--seed=42",
"--num_workers=0",
]
result = run_accelerate_training(
config_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
)
assert result.returncode == 0, (
f"FSDP training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
)
# The gathered optimizer state must be written under FSDP (proves the save collective ran),
# in the same safetensors format as single-GPU training.
training_state_dir = output_dir / "checkpoints" / "last" / "training_state"
optimizer_state = training_state_dir / "optimizer_state.safetensors"
optimizer_param_groups = training_state_dir / "optimizer_param_groups.json"
assert optimizer_state.exists(), f"FSDP optimizer state not saved in {training_state_dir}"
assert optimizer_param_groups.exists(), (
f"FSDP optimizer param groups not saved in {training_state_dir}"
)
# Resume from the checkpoint for more steps. A successful run proves load_fsdp_optimizer
# accepts the saved state and reshards it without shape/key errors.
resume_config = output_dir / "checkpoints" / "last" / "pretrained_model" / "train_config.json"
resume_args = [
f"--config_path={resume_config}",
"--resume=true",
"--steps=20",
]
resume_result = run_accelerate_training(
resume_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
)
assert resume_result.returncode == 0, (
f"FSDP resume failed:\nSTDOUT:\n{resume_result.stdout}\n\nSTDERR:\n{resume_result.stderr}"
)
assert "End of training" in resume_result.stdout or "End of training" in resume_result.stderr
+15
View File
@@ -136,3 +136,18 @@ def test_save_load_training_state(tmp_path, optimizer, scheduler):
assert loaded_step == 10
assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler
def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
# FSDP loads optimizer separately (after accelerator.prepare)
# load_training_state(load_optimizer=False) must restore step + scheduler but leave the
# optimizer untouched and never touch the on-disk optimizer state.
save_training_state(tmp_path, 10, optimizer, scheduler)
with patch("lerobot.common.train_utils.load_optimizer_state") as mock_load_optimizer_state:
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(
tmp_path, optimizer, scheduler, load_optimizer=False
)
mock_load_optimizer_state.assert_not_called()
assert loaded_step == 10
assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler
Generated
+900 -949
View File
File diff suppressed because it is too large Load Diff