Compare commits

..

5 Commits

Author SHA1 Message Date
Martino Russi 1c118c6359 feat(unitree_g1): add SONIC whole-body controller
Move GrootLocomotionController and HolosomaLocomotionController into a new
controllers/ subpackage and add the SONIC whole-body controller
(sonic_pipeline.py, sonic_whole_body.py) plus the examples/unitree_g1/sonic.py
standalone script. UnitreeG1 now honors a controller's kp/kd, calls
controller.shutdown() on disconnect, and skips arm publishing for full_body
controllers.
2026-06-16 17:12:20 +02:00
Pepijn 58ccc01508 fix(datasets): enforce one parquet row group per episode in v3 data writes (#3807)
* fix(datasets): enforce one parquet row group per episode in v3 data writes

LeRobot v3 data shards must hold exactly one row group per episode so a
reader can fetch episode i with pq.ParquetFile(path).read_row_group(i)
(a byte-range read) instead of loading the whole shard. The recording
writer already does this (one write_table per episode); the aggregate
and lerobot-annotate re-write paths instead concatenated many episodes
and wrote them in one shot, collapsing the file to a single row group.

- io_utils: add write_table_one_row_group_per_episode (one ParquetWriter,
  one write_table per episode — same pattern as the recording writer);
  to_parquet_with_hf_images embeds images then writes per-episode row
  groups; to_parquet_one_row_group_per_episode wraps it for plain frames
- aggregate: route non-image data writes through the per-episode writer;
  leave the episodes-metadata parquet untouched (already one row/episode)
- annotate: rewrite shards via the per-episode writer instead of a single
  bulk pq.write_table
- tests: invariant coverage through the aggregate (image + video) and
  annotate paths

No change to on-disk schema, paths, naming, rollover thresholds, or
compression. Readers stay backward-compatible (old collapsed files load).

* Update src/lerobot/datasets/io_utils.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update src/lerobot/datasets/io_utils.py

Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* fix(datasets): correct indentation and add strict= in row-group helper

The web-edited numpy version of write_table_one_row_group_per_episode had an
over-indented line (IndentationError, breaking pre-commit + test collection)
and a zip() without strict=. Fix both; behaviour unchanged.

---------

Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-06-16 12:15:48 +02:00
Caroline Pascal 38327fdc84 fix(images/videos): fixing aggregate_pipeline_dataset_features to avoid unwanted images features deletion (#3783)
* fix(images/videos): fixing aggregate_pipeline_dataset_features to avoid unwanted images features deletion when videos are not used

* fix(docstrings): improving docstrings

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>

---------

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
2026-06-15 17:55:52 +02:00
Steven Palma 9555efc02c chore(dependencies): update uv.lock (#3595)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-06-15 16:29:44 +02:00
Steven Palma d576c59afb refactor(robots): homogenize bi-manual setups implementations (#3772)
* chore(robots): homogenize bi setups

* feat(robots): split openarm mini into single and bi

* refactor(robots): mixin for bi classes

* docs: update docs
2026-06-15 16:28:54 +02:00
63 changed files with 2944 additions and 1520 deletions
+3 -3
View File
@@ -167,9 +167,9 @@ jobs:
# ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
# immediately runs eval inside the training loop (env_eval_freq=1, 1 episode).
# immediately runs eval inside the training loop (eval_freq=1, 1 episode).
# Tests the full train→eval-within-training pipeline end-to-end.
- name: Run Libero train+eval smoke (1 step, env_eval_freq=1)
- name: Run Libero train+eval smoke (1 step, eval_freq=1)
if: env.HF_USER_TOKEN != ''
run: |
docker run --name libero-train-smoke --gpus all \
@@ -196,7 +196,7 @@ jobs:
--output_dir=/tmp/train-smoke \
--steps=1 \
--batch_size=1 \
--env_eval_freq=1 \
--eval_freq=1 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--eval.use_async_envs=false \
+4 -4
View File
@@ -58,7 +58,7 @@ test-act-ete-train:
--dataset.episodes="[0]" \
--batch_size=2 \
--steps=4 \
--env_eval_freq=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_freq=2 \
@@ -96,7 +96,7 @@ test-diffusion-ete-train:
--dataset.episodes="[0]" \
--batch_size=2 \
--steps=2 \
--env_eval_freq=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_checkpoint=true \
@@ -126,7 +126,7 @@ test-tdmpc-ete-train:
--dataset.episodes="[0]" \
--batch_size=2 \
--steps=2 \
--env_eval_freq=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_checkpoint=true \
@@ -161,7 +161,7 @@ test-smolvla-ete-train:
--dataset.episodes="[0]" \
--batch_size=2 \
--steps=4 \
--env_eval_freq=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_freq=2 \
+8 -8
View File
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
**Compatible teleoperators:**
- `openarm_mini` - OpenArm Mini
- `bi_openarm_mini` - Bimanual OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \
+1 -1
View File
@@ -719,7 +719,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
"num_workers": 4,
"steps": 5000,
"log_freq": 10,
"env_eval_freq": 1000,
"eval_freq": 1000,
"save_freq": 1000,
"save_checkpoint": true,
"seed": 2,
+1 -1
View File
@@ -117,7 +117,7 @@ lerobot-rollout \
--strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \
--teleop.type=openarm_mini \
--teleop.type=bi_openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt"
```
+1 -1
View File
@@ -143,7 +143,7 @@ lerobot-train \
--batch_size=4 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--env_eval_freq=1000
--eval_freq=1000
```
## Reproducing published results
+1 -1
View File
@@ -173,7 +173,7 @@ lerobot-train \
--batch_size=4 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--env_eval_freq=1000
--eval_freq=1000
```
## Relationship to LIBERO
+2 -2
View File
@@ -120,11 +120,11 @@ lerobot-train \
--batch_size=4 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--env_eval_freq=1000
--eval_freq=1000
```
## Practical tips
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
- Adjust `batch_size`, `steps`, and `env_eval_freq` to match your compute budget.
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
+2 -2
View File
@@ -103,7 +103,7 @@ accelerate launch \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--env_eval_freq=-1 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
@@ -142,7 +142,7 @@ accelerate launch \
--batch_size=32 \
--num_workers=4 \
--log_freq=20 \
--env_eval_freq=-1 \
--eval_freq=-1 \
--save_checkpoint=true \
--save_freq=2000
```
+1 -1
View File
@@ -314,7 +314,7 @@ lerobot-train \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--env_eval_freq=1000 \
--eval_freq=1000 \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.horizon=32 \
+1 -1
View File
@@ -166,7 +166,7 @@ lerobot-train \
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
--steps=100000 \
--batch_size=4 \
--env_eval_freq=5000 \
--eval_freq=5000 \
--eval.batch_size=1 \
--eval.n_episodes=5 \
--save_freq=10000
+1 -1
View File
@@ -165,7 +165,7 @@ lerobot-train \
--output_dir=./outputs/smolvla_vlabench_primitive \
--steps=100000 \
--batch_size=4 \
--env_eval_freq=5000 \
--eval_freq=5000 \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--save_freq=10000
+217
View File
@@ -0,0 +1,217 @@
#!/usr/bin/env python
"""
SONIC planner with full mode control.
Keyboard controls:
N / P - next / previous motion set
1-8 - select mode within current set
WASD - movement direction
Q / E - rotate facing left / right
9 / 0 - decrease / increase speed
- / = - decrease / increase height
R - force replan
Space - emergency stop -> IDLE
Esc - quit
Gamepad controls (Unitree wireless controller):
Left stick Y - speed (forward = fast, back = stop)
Left stick X - movement direction (offset from facing)
Right stick X - facing direction (incremental rotation)
Right stick Y - height (up = tall 0.8m, down = low 0.1m)
Buttons - unused (mode selection is keyboard-only)
For teleop integration use --robot.controller=SonicWholeBodyController instead.
"""
import argparse
import gc
import time
import numpy as np
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.controllers.sonic_whole_body import SonicRuntime
from lerobot.robots.unitree_g1.controllers.sonic_pipeline import (
CONTROL_DT,
DEFAULT_ANGLES,
LM,
MOTION_SETS,
RawKeyboard,
compute_kp_kd,
drain_keyboard,
)
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
def main():
parser = argparse.ArgumentParser(description="SONIC planner with keyboard + gamepad control")
parser.add_argument("--ip", type=str, default=None,
help="Robot IP for real hardware (e.g. 192.168.123.164). "
"Omit for simulation.")
parser.add_argument("--log-csv", action="store_true",
help="Write /tmp/sonic_pose_log.csv (disabled by default for teleop perf)")
parser.add_argument("--cpu", action="store_true",
help="Force CPU ONNX Runtime (skip CUDA even if onnxruntime-gpu is installed)")
parser.add_argument("--headless", action="store_true",
help="Ignored for sim (stock UnitreeG1 uses hub MuJoCo defaults)")
parser.add_argument("--gamepad", action="store_true",
help="Read Unitree wireless gamepad in sim (default: keyboard-only in sim)")
parser.add_argument("--keyboard-only", action="store_true",
help="Ignore wireless gamepad (terminal keyboard only)")
args = parser.parse_args()
print("=" * 60)
print("SONIC planner - full mode control")
print(" N/P cycle sets | 1-8 select mode | WASD move")
print(" Q/E rotate | 9/0 speed | -/= height")
print(" R replan | Space IDLE | Esc quit")
if args.ip:
print(f" Robot IP: {args.ip}")
else:
print(" Mode: simulation")
print("=" * 60 + "\n")
cfg = UnitreeG1Config(controller=None) # full-body SONIC; standalone loop owns publish
if args.ip:
cfg.is_simulation = False
cfg.robot_ip = args.ip
else:
cfg.is_simulation = True
if args.headless:
print("[Note] --headless ignored: sim uses stock UnitreeG1 + hub env")
robot = UnitreeG1(cfg)
robot.connect()
kp, kd = compute_kp_kd()
robot.kp = kp.copy()
robot.kd = kd.copy()
runtime = SonicRuntime(force_cpu=args.cpu)
controller = runtime.controller
ms = runtime.ms
runtime.controller.print_input_diagnostics()
print(f"\nStarting: {MOTION_SETS[0][0]} (default mode: {LM(ms.mode).name})")
[print(f" {i+1}: {m.name}") for i, m in enumerate(MOTION_SETS[0][1])]
print("\n[Ready] Click THIS terminal, then W/A/S/D to move. "
"1-6 change mode, 9/0 speed, Esc quit.\n", flush=True)
# Sim hub publishes wireless_remote bytes that can fight terminal WASD.
use_joystick = not args.keyboard_only and (args.gamepad or args.ip is not None)
with RawKeyboard() as kb:
try:
gc.disable()
gc_timer = 0.0
robot.reset(CONTROL_DT, DEFAULT_ANGLES)
time.sleep(1.0)
last_status = time.time() - 2.1
loop_t = enc_t = dec_t = obs_t = act_t = []
slow_n = blend_n = 0
stall_src = ""
did_blend = False
prev_end = time.time()
t_start = time.time()
log_path = "/tmp/sonic_pose_log.csv"
jnames = [m.name for m in G1_29_JointIndex]
log_ctx = open(log_path, "w") if args.log_csv else None
if log_ctx:
log_ctx.write("t,step,cursor,ts,blend,mode," +
",".join(f"q{i}" for i in range(29)) + "," +
",".join(f"ref{i}" for i in range(29)) + "," +
",".join(f"act{i}" for i in range(29)) +
",delta_max,action_norm,token_norm\n")
try:
while not robot._shutdown_event.is_set():
t0 = time.time()
if drain_keyboard(kb, ms, controller):
break
obs = robot.get_observation()
t_obs = time.time()
obs_t.append(1000 * (t_obs - t0))
if not obs:
runtime.tick({}, use_joystick=False)
time.sleep(max(0.0, CONTROL_DT - (time.time() - t0)))
continue
step_before = runtime.step
t_step = time.time()
action = runtime.tick(obs, use_joystick=use_joystick)
step_ms = 1000 * (time.time() - t_step)
do_enc = step_before % 5 == 0
(enc_t if do_enc else dec_t).append(step_ms)
t_act = time.time()
robot.send_action(action)
act_t.append(1000 * (time.time() - t_act))
if log_ctx and runtime.step % 5 == 0:
t_rel = time.time() - t_start
q_r = np.array([obs.get(f"{n}.q", 0) for n in jnames])
a_v = np.array([action.get(f"{n}.q", 0) for n in jnames])
cur, ts = controller.ref_cursor, controller.motion_timesteps
q_ref = controller.motion_joint_positions[min(cur, ts - 1)] if ts > 0 else np.zeros(29)
log_ctx.write(f"{t_rel:.4f},{runtime.step},{cur},{ts},{int(did_blend)},{ms.mode}," +
",".join(f"{v:.6f}" for v in q_r) + "," +
",".join(f"{v:.6f}" for v in q_ref) + "," +
",".join(f"{v:.6f}" for v in a_v) + "," +
f"{np.max(np.abs(a_v - q_r)):.6f},"
f"{np.linalg.norm(a_v):.6f},"
f"{np.linalg.norm(controller.token):.6f}\n")
did_blend = False
now = time.time()
loop_ms = 1000 * (now - t0)
if loop_ms > 50:
stall_src = (f"[STALL] {loop_ms:.0f}ms: "
f"obs={obs_t[-1]:.0f} step={step_ms:.0f} act={act_t[-1]:.0f}")
if loop_ms > CONTROL_DT * 1500:
slow_n += 1
if now - last_status > 2.0:
def _avg(lst):
return sum(lst) / len(lst) if lst else 0
hz = 1000 / _avg(loop_t) if _avg(loop_t) else 0
print(f"\r {ms.status_line()} step={runtime.step} "
f"ref={controller.ref_cursor}/{controller.motion_timesteps} "
f"loop={_avg(loop_t):.1f}ms(max={max(loop_t, default=0):.1f}) hz={hz:.0f} "
f"enc={_avg(enc_t):.1f} dec={_avg(dec_t):.1f} obs={_avg(obs_t):.1f} "
f"slow={slow_n} blends={blend_n}", end="", flush=True)
if stall_src:
print(f"\n {stall_src}")
stall_src = ""
last_status = now
loop_t = enc_t = dec_t = obs_t = act_t = []
slow_n = blend_n = 0
prev_end = time.time()
gc_timer += CONTROL_DT
if gc_timer >= 10.0:
gc.collect()
gc_timer = 0.0
loop_t.append(loop_ms)
time.sleep(max(0.0, CONTROL_DT - (time.time() - t0)))
finally:
if log_ctx:
log_ctx.close()
except KeyboardInterrupt:
pass
finally:
gc.enable()
if args.log_csv:
print(f"\n[Log] Saved to {log_path}")
runtime.shutdown()
print("\nStopping...")
if robot.is_connected:
robot.disconnect()
print("Done.")
if __name__ == "__main__":
main()
@@ -54,6 +54,7 @@ from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
@@ -274,12 +275,11 @@ class LanguageColumnsWriter:
new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
)
# Atomic replace: write to a sibling tmp path and rename so a crash
# mid-write can't leave a half-written shard that ``pq.read_table``
# would then fail to open. ``Path.replace`` is atomic on POSIX +
# Windows when source and target sit on the same filesystem.
# Re-emit one row group per episode (a bulk pq.write_table would collapse
# them into one). Write to a sibling tmp path and atomically rename so a
# crash mid-write can't leave a half-written shard.
tmp_path = path.with_suffix(path.suffix + ".tmp")
pq.write_table(new_table, tmp_path)
write_table_one_row_group_per_episode(new_table, tmp_path)
tmp_path.replace(path)
def _materialize_table(
-2
View File
@@ -39,8 +39,6 @@ class DatasetConfig:
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False
# Fraction of episodes held out per task for offline evaluation (0.0 = disabled).
eval_split: float = 0.0
def __post_init__(self) -> None:
if self.episodes is not None:
+1 -6
View File
@@ -100,13 +100,8 @@ class TrainPipelineConfig(HubMixin):
prefetch_factor: int = 4
persistent_workers: bool = True
steps: int = 100_000
# Run policy in the simulation environment every N steps to measure reward/success (0 = disabled).
env_eval_freq: int = 20_000
eval_freq: int = 20_000
log_freq: int = 200
# Compute eval loss on held-out episodes every N steps (0 = disabled). Requires eval_split > 0.
eval_steps: int = 0
# Cap on total eval samples, split uniformly across tasks (0 = use all held-out data).
max_eval_samples: int = 0
tolerance_s: float = 1e-4
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
+1 -2
View File
@@ -35,7 +35,7 @@ from .dataset_tools import (
remove_feature,
split_dataset,
)
from .factory import make_dataset, make_train_eval_datasets, resolve_delta_timestamps
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
@@ -89,7 +89,6 @@ __all__ = [
"get_feature_stats",
"load_episodes",
"make_dataset",
"make_train_eval_datasets",
"merge_datasets",
"modify_features",
"modify_tasks",
+9
View File
@@ -32,6 +32,7 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
from .io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
to_parquet_one_row_group_per_episode,
to_parquet_with_hf_images,
write_info,
write_stats,
@@ -551,6 +552,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
aggr_root=dst_meta.root,
hf_features=hf_features,
concatenate=concatenate_data,
one_row_group_per_episode=True,
)
# Record the mapping from source to actual destination
@@ -628,6 +630,7 @@ def append_or_create_parquet_file(
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
concatenate: bool = True,
one_row_group_per_episode: bool = False,
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -645,6 +648,8 @@ def append_or_create_parquet_file(
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one.
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
the episodes-metadata parquet (already one row per episode).
Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -657,6 +662,8 @@ def append_or_create_parquet_file(
dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(df, dst_path)
else:
df.to_parquet(dst_path)
return idx, (dst_chunk, dst_file)
@@ -683,6 +690,8 @@ def append_or_create_parquet_file(
if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(final_df, target_path)
else:
final_df.to_parquet(target_path)
-79
View File
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from pprint import pformat
import torch
@@ -131,81 +130,3 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset
def make_train_eval_datasets(
cfg: TrainPipelineConfig,
) -> tuple[LeRobotDataset | MultiLeRobotDataset, LeRobotDataset | None]:
"""Create train and optional eval datasets by splitting episodes based on eval_split.
The last ceil(n_episodes * eval_split) episodes per task are held out for evaluation.
If eval_split == 0.0, returns (full_dataset, None).
"""
full_dataset = make_dataset(cfg)
if cfg.dataset.eval_split == 0.0:
return full_dataset, None
base_episodes = (
full_dataset.episodes if full_dataset.episodes is not None else list(range(full_dataset.num_episodes))
)
episode_tasks = full_dataset.meta.episodes["tasks"]
task_to_episodes: dict[str, list[int]] = {}
for ep_idx in base_episodes:
task_key = episode_tasks[ep_idx][0] if episode_tasks[ep_idx] else ""
task_to_episodes.setdefault(task_key, []).append(ep_idx)
train_episodes, eval_episodes = [], []
for eps in task_to_episodes.values():
n_eval = math.ceil(len(eps) * cfg.dataset.eval_split)
train_episodes.extend(eps[: len(eps) - n_eval])
eval_episodes.extend(eps[len(eps) - n_eval :])
if not train_episodes:
raise ValueError(
f"eval_split={cfg.dataset.eval_split} leaves 0 training episodes from {len(base_episodes)} total."
)
logging.info(
f"Train/eval split: {len(train_episodes)} train, {len(eval_episodes)} eval "
f"(eval_split={cfg.dataset.eval_split}, {len(task_to_episodes)} tasks)"
)
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, full_dataset.meta)
train_image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)
train_dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=train_episodes,
delta_timestamps=delta_timestamps,
image_transforms=train_image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
return_uint8=True,
tolerance_s=cfg.tolerance_s,
)
eval_dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=eval_episodes,
delta_timestamps=delta_timestamps,
image_transforms=None,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
return_uint8=True,
tolerance_s=cfg.tolerance_s,
)
if cfg.dataset.use_imagenet_stats:
for ds in (train_dataset, eval_dataset):
for key in ds.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
ds.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return train_dataset, eval_dataset
+38 -9
View File
@@ -20,6 +20,7 @@ import datasets
import numpy as np
import pandas
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
@@ -270,21 +271,49 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
return items_dict
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
"""Write ``table`` with one parquet row group per episode (in episode order).
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
mirroring the recording writer. ``table`` must carry a contiguous
``episode_index`` column.
"""
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
try:
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
finally:
writer.close()
def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None:
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
Args:
df: DataFrame to write to parquet.
path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
Images are embedded into the arrow table first (``ParquetWriter.write_table``
does not embed external image files like ``Dataset.to_parquet`` does).
``features`` types image columns as ``Image()`` in the parquet schema.
"""
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds.to_parquet(path)
ds = embed_images(ds)
table = ds.with_format("arrow")[:]
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
# No episode boundaries to align row groups to — keep a single write.
pq.write_table(table, str(path))
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
table = pa.Table.from_pandas(df, preserve_index=False)
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
pq.write_table(table, str(path))
def item_to_torch(item: dict) -> dict:
-2
View File
@@ -474,8 +474,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
if reader.hf_dataset is None:
# One-shot load after finalize()
reader.load_and_activate()
if reader._absolute_to_relative_idx is not None and idx in reader._absolute_to_relative_idx:
idx = reader._absolute_to_relative_idx[idx]
return reader.get_item(idx)
def select_columns(self, column_names: str | list[str]):
+5 -3
View File
@@ -70,19 +70,21 @@ def aggregate_pipeline_dataset_features(
initial_features: dict[PipelineFeatureType, dict[str, Any]],
*,
use_videos: bool = True,
exclude_images: bool = False,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `use_videos` and `patterns`, and finally
(image or state), filters them based on `exclude_images` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset.
Args:
pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: If False, image features are excluded.
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
exclude_images: If True, image features are dropped entirely from the output.
patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter.
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
)
# 2. Apply filtering rules.
if is_image and not use_videos:
if is_image and exclude_images:
continue
if not is_image and not should_keep(key, compiled_patterns):
continue
@@ -18,7 +18,8 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot
@@ -27,7 +28,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
logger = logging.getLogger(__name__)
class BiOpenArmFollower(Robot):
class BiOpenArmFollower(BimanualMixin, Robot):
"""
Bimanual OpenArm Follower Arms
"""
@@ -39,15 +40,17 @@ class BiOpenArmFollower(Robot):
super().__init__(config)
self.config = config
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
# will only open the cameras assigned to it. Per-arm cameras are used
# as fallback when top-level cameras are empty.
if config.cameras:
left_cameras = config.cameras
right_cameras = {}
else:
left_cameras = config.left_arm_config.cameras
right_cameras = config.right_arm_config.cameras
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None,
@@ -56,7 +59,7 @@ class BiOpenArmFollower(Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_cameras,
cameras=left_arm_cameras,
side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd,
@@ -75,7 +78,7 @@ class BiOpenArmFollower(Robot):
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=right_cameras,
cameras=config.right_arm_config.cameras,
side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd,
@@ -95,22 +98,19 @@ class BiOpenArmFollower(Robot):
@property
def _motors_ft(self) -> dict[str, type]:
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return {
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
}
@property
def _cameras_ft(self) -> dict[str, tuple]:
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
# "right_wrist"), so we merge them directly — unlike motors which need the
# left_/right_ prefix to disambiguate identical per-arm joint names.
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -120,27 +120,6 @@ class BiOpenArmFollower(Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -148,21 +127,15 @@ class BiOpenArmFollower(Robot):
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
obs_dict: RobotObservation = {}
# Camera keys that should NOT get the arm prefix (they already have unique names)
left_cam_keys = set(self.left_arm.cameras.keys())
right_cam_keys = set(self.right_arm.cameras.keys())
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
right_obs = self.right_arm.get_observation()
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
# Add "right_" prefix
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
return obs_dict
@@ -189,9 +162,4 @@ class BiOpenArmFollower(Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@@ -32,5 +32,7 @@ class BiOpenArmFollowerConfig(RobotConfig):
left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras shared across both arms.
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,7 +18,8 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
from ..robot import Robot
@@ -27,7 +28,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
logger = logging.getLogger(__name__)
class BiRebotB601Follower(Robot):
class BiRebotB601Follower(BimanualMixin, Robot):
"""Bimanual Seeed Studio reBot B601-DM follower.
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
@@ -41,6 +42,18 @@ class BiRebotB601Follower(Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -49,7 +62,7 @@ class BiRebotB601Follower(Robot):
dm_serial_baud=config.left_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=config.left_arm_config.cameras,
cameras=left_arm_cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
@@ -86,10 +99,12 @@ class BiRebotB601Follower(Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
}
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -99,32 +114,13 @@ class BiRebotB601Follower(Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
obs_dict: RobotObservation = {}
for k, v in self.left_arm.get_observation().items():
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm.get_observation().items():
obs_dict[f"right_{k}"] = v
return obs_dict
@check_if_not_connected
@@ -143,8 +139,3 @@ class BiRebotB601Follower(Robot):
**{f"left_{k}": v for k, v in sent_action_left.items()},
**{f"right_{k}": v for k, v in sent_action_right.items()},
}
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
from ..rebot_b601_follower import RebotB601FollowerConfig
@@ -27,3 +29,8 @@ class BiRebotB601FollowerConfig(RobotConfig):
left_arm_config: RebotB601FollowerConfig
right_arm_config: RebotB601FollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,7 +18,8 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..robot import Robot
from ..so_follower import SOFollower, SOFollowerRobotConfig
@@ -27,7 +28,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
logger = logging.getLogger(__name__)
class BiSOFollower(Robot):
class BiSOFollower(BimanualMixin, Robot):
"""
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
@@ -39,6 +40,18 @@ class BiSOFollower(Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = SOFollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -46,7 +59,7 @@ class BiSOFollower(Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
use_degrees=config.left_arm_config.use_degrees,
cameras=config.left_arm_config.cameras,
cameras=left_arm_cameras,
)
right_arm_config = SOFollowerRobotConfig(
@@ -77,13 +90,12 @@ class BiSOFollower(Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
left_arm_cameras_ft = self.left_arm._cameras_ft
right_arm_cameras_ft = self.right_arm._cameras_ft
return {
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
}
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -93,42 +105,21 @@ class BiSOFollower(Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict = {}
obs_dict: RobotObservation = {}
# Add "left_" prefix
left_obs = self.left_arm.get_observation()
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Add "right_" prefix
right_obs = self.right_arm.get_observation()
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
return obs_dict
@@ -151,8 +142,3 @@ class BiSOFollower(Robot):
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig
from ..so_follower import SOFollowerConfig
@@ -27,3 +29,8 @@ class BiSOFollowerConfig(RobotConfig):
left_arm_config: SOFollowerConfig
right_arm_config: SOFollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -68,6 +68,6 @@ class UnitreeG1Config(RobotConfig):
# Compensates for gravity on the unitree's arms using the arm ik solver
gravity_compensation: bool = False
# Lower-body controller class name, e.g. "GrootLocomotionController" or
# "HolosomaLocomotionController". None disables it.
# Locomotion controller class name, e.g. "GrootLocomotionController",
# "HolosomaLocomotionController", or "SonicWholeBodyController". None disables it.
controller: str | None = None
@@ -0,0 +1,8 @@
"""Unitree G1 locomotion controllers (Groot, Holosoma, SONIC)."""
__all__ = [
"GrootLocomotionController",
"HolosomaLocomotionController",
"SonicWholeBodyController",
"SonicRuntime",
]
@@ -21,7 +21,7 @@ import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from .g1_utils import (
from lerobot.robots.unitree_g1.g1_utils import (
REMOTE_AXES,
REMOTE_BUTTONS,
G1_29_JointIndex,
@@ -22,7 +22,7 @@ import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from .g1_utils import (
from lerobot.robots.unitree_g1.g1_utils import (
REMOTE_AXES,
G1_29_JointArmIndex,
G1_29_JointIndex,
@@ -0,0 +1,913 @@
"""SONIC planner pipeline: ONNX enc/dec/planner, movement state, and input helpers."""
import math
import queue
import select
import struct
import sys
import termios
import threading
import time
import tty
from dataclasses import dataclass
from enum import IntEnum
import numpy as np
import onnxruntime as ort
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
# ── Constants ────────────────────────────────────────────────────────────────
DEFAULT_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
0.0, 0.0, 0.0,
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
], dtype=np.float32)
NATURAL_FREQ = 10.0 * 2.0 * np.pi
ARMATURE = {"5020": 0.003609725, "7520_14": 0.010177520, "7520_22": 0.025101925, "4010": 0.00425}
EFFORT = {"5020": 25.0, "7520_14": 88.0, "7520_22": 139.0, "4010": 5.0}
def _action_scale(k):
return 0.25 * EFFORT[k] / (ARMATURE[k] * NATURAL_FREQ**2)
_J = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
["7520_14","5020","5020"] + \
["5020","5020","5020","5020","5020","4010","4010"] * 2
ACTION_SCALE = np.array([_action_scale(k) for k in _J], dtype=np.float32)
CONTROL_DT = 0.02
DEFAULT_HEIGHT = 0.788740
TOKEN_DIM = 64
ENCODER_UPDATE_EVERY = 5
DEBUG_PRINT_EVERY = 100
MOTION_LOOK_AHEAD_STEPS = 2
INITIAL_RANDOM_SEED = 1234
MIN_TOKENS, MAX_TOKENS = 6, 16
K = MAX_TOKENS - MIN_TOKENS + 1
DEADZONE = 0.05
BLEND_FRAMES = 8
REPLAN_INTERVAL = {
"running": 0.1, "crawling": 0.2, "boxing": 1.0, "default": 1.0
}
ISAACLAB_TO_MUJOCO = np.array([
0, 3, 6, 9, 13, 17, 1, 4, 7, 10, 14, 18, 2, 5, 8,
11, 15, 19, 21, 23, 25, 27, 12, 16, 20, 22, 24, 26, 28
], dtype=np.int32)
MUJOCO_TO_ISAACLAB = np.array([
0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 22, 4, 10,
16, 23, 5, 11, 17, 24, 18, 25, 19, 26, 20, 27, 21, 28
], dtype=np.int32)
def _to_mujoco(a): return a[MUJOCO_TO_ISAACLAB]
def _to_runtime(a): r = np.zeros(29, np.float32); r[MUJOCO_TO_ISAACLAB] = a; return r
DEFAULT_ANGLES_MUJOCO = _to_mujoco(DEFAULT_ANGLES)
ENCODER_STANDING_REF = DEFAULT_ANGLES.copy()
LOWER_BODY_IL = np.array([0,3,6,9,13,17,1,4,7,10,14,18], dtype=np.int32)
WRIST_IL = np.array([23,24,25,26,27,28], dtype=np.int32)
VR_TARGET_DEF = np.zeros(9, dtype=np.float32)
VR_ORN_DEF = np.array([1,0,0,0,1,0,0,0,1,0,0,0], dtype=np.float32)
SMPL_DEF = np.zeros(720, dtype=np.float32)
# ── PD gains ─────────────────────────────────────────────────────────────────
def compute_kp_kd():
s = lambda k: ARMATURE[k] * NATURAL_FREQ**2
d = lambda k: 2.0 * 2.0 * ARMATURE[k] * NATURAL_FREQ
_kp_keys = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
["7520_14","5020","5020"] + \
["5020","5020","5020","5020","5020","4010","4010"] * 2
_kd_keys = _kp_keys
_double = {4,5,10,11,13,14} # ankle + waist indices with factor 2
kp = np.array([2*s(k) if i in _double else s(k) for i,k in enumerate(_kp_keys)], dtype=np.float32)
kd = np.array([2*d(k) if i in _double else d(k) for i,k in enumerate(_kd_keys)], dtype=np.float32)
return kp, kd
_kp_kd = compute_kp_kd # backward-compatible alias
def lowstate_to_obs(lowstate) -> dict:
"""Build a robot observation dict from Unitree lowstate."""
obs: dict = {}
for motor in G1_29_JointIndex:
idx = motor.value
obs[f"{motor.name}.q"] = float(lowstate.motor_state[idx].q)
obs[f"{motor.name}.dq"] = float(lowstate.motor_state[idx].dq)
quat = lowstate.imu_state.quaternion
obs["imu.quat.w"] = float(quat[0])
obs["imu.quat.x"] = float(quat[1])
obs["imu.quat.y"] = float(quat[2])
obs["imu.quat.z"] = float(quat[3])
gyro = lowstate.imu_state.gyroscope
obs["imu.gyro.x"] = float(gyro[0])
obs["imu.gyro.y"] = float(gyro[1])
obs["imu.gyro.z"] = float(gyro[2])
wr = getattr(lowstate, "wireless_remote", None)
if wr is not None:
obs["wireless_remote"] = bytes(wr) if not isinstance(wr, (bytes, bytearray)) else wr
return obs
# ── Quaternion helpers ────────────────────────────────────────────────────────
def quat_conj(q):
return np.array([q[0], -q[1], -q[2], -q[3]], dtype=np.float32)
def quat_mul(q1, q2):
w1,x1,y1,z1 = q1; w2,x2,y2,z2 = q2
return np.array([
w1*w2 - x1*x2 - y1*y2 - z1*z2,
w1*x2 + x1*w2 + y1*z2 - z1*y2,
w1*y2 - x1*z2 + y1*w2 + z1*x2,
w1*z2 + x1*y2 - y1*x2 + z1*w2,
], dtype=np.float32)
def gravity_dir(q):
q = q / (np.linalg.norm(q) + 1e-8)
qv = np.array([0, 0, 0, -1], dtype=np.float32)
return quat_mul(quat_mul(quat_conj(q), qv), q)[1:]
def quat_to_6d(q):
w,x,y,z = q
return np.array([
1-2*(y*y+z*z), 2*(x*y-z*w),
2*(x*y+z*w), 1-2*(x*x+z*z),
2*(x*z-y*w), 2*(y*z+x*w),
], dtype=np.float32)
def calc_heading(q):
w,x,y,z = q
return float(np.arctan2(2*(x*y + w*z), 1-2*(y*y+z*z)))
def heading_quat(q, sign=1.0):
a = sign * calc_heading(q) / 2.0
return np.array([np.cos(a), 0, 0, np.sin(a)], dtype=np.float64)
heading_quat_inv = lambda q: heading_quat(q, -1.0)
def quat_slerp(q0, q1, t):
q0 = q0 / (np.linalg.norm(q0)+1e-12); q1 = q1 / (np.linalg.norm(q1)+1e-12)
dot = float(np.dot(q0, q1))
if dot < 0: q1, dot = -q1, -dot
dot = min(dot, 1.0)
if dot > 0.9995:
r = q0 + t*(q1-q0); return r/(np.linalg.norm(r)+1e-12)
th = np.arccos(dot); st = np.sin(th)
return (np.sin((1-t)*th)/st)*q0 + (np.sin(t*th)/st)*q1
def quat_slerp_batch(q0, q1, t):
q0 = q0 / (np.linalg.norm(q0,axis=1,keepdims=True)+1e-12)
q1 = q1 / (np.linalg.norm(q1,axis=1,keepdims=True)+1e-12)
dot = np.sum(q0*q1, axis=1); neg = dot<0
q1=q1.copy(); q1[neg]=-q1[neg]; dot[neg]=-dot[neg]; dot=np.clip(dot,-1,1)
lin = dot>0.9995; th=np.arccos(dot); st=np.where(np.sin(th)==0,1,np.sin(th))
c0=np.sin((1-t)*th)/st; c1=np.sin(t*th)/st
c0[lin]=1-t[lin]; c1[lin]=t[lin]
r = c0[:,None]*q0 + c1[:,None]*q1
return r / (np.linalg.norm(r,axis=1,keepdims=True)+1e-12)
# ── Locomotion modes ──────────────────────────────────────────────────────────
class LocomotionMode(IntEnum):
IDLE=0; SLOW_WALK=1; WALK=2; RUN=3; SQUAT=4; KNEEL_TWO_LEGS=5; KNEEL=6
LYING_FACE_DOWN=7; CRAWLING=8; IDLE_BOXING=9; WALK_BOXING=10
LEFT_PUNCH=11; RIGHT_PUNCH=12; RANDOM_PUNCH=13; ELBOW_CRAWLING=14
LEFT_HOOK=15; RIGHT_HOOK=16; FORWARD_JUMP=17; STEALTH_WALK=18
INJURED_WALK=19; LEDGE_WALKING=20; OBJECT_CARRYING=21; STEALTH_WALK_2=22
HAPPY_DANCE_WALK=23; ZOMBIE_WALK=24; GUN_WALK=25; SCARE_WALK=26
LM = LocomotionMode
MOTION_SETS = [
("Standing", [LM.SLOW_WALK, LM.WALK, LM.RUN, LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK]),
("Squat / Low", [LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.CRAWLING, LM.ELBOW_CRAWLING]),
("Boxing", [LM.IDLE_BOXING, LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK]),
("Styled Walks", [LM.LEDGE_WALKING, LM.OBJECT_CARRYING, LM.STEALTH_WALK_2,
LM.HAPPY_DANCE_WALK, LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK]),
]
STATIC_MODES = {LM.IDLE, LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.LYING_FACE_DOWN, LM.IDLE_BOXING}
STANDING_MODES = {LM.IDLE, LM.SLOW_WALK, LM.WALK, LM.RUN, LM.IDLE_BOXING, LM.WALK_BOXING,
LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK,
LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK, LM.LEDGE_WALKING,
LM.OBJECT_CARRYING, LM.STEALTH_WALK_2, LM.HAPPY_DANCE_WALK,
LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK}
BOXING_MODES = {LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}
SPEED_RANGES = {LM.SLOW_WALK:(0.2,0.8), LM.WALK:(0.8,1.5), LM.RUN:(1.5,3.0),
LM.CRAWLING:(0.4,1.0), LM.ELBOW_CRAWLING:(0.7,1.0)}
def clamp_mode_params(ms):
m = LM(ms.mode)
ms.height = -1.0 if m in STANDING_MODES else max(0.1, min(0.8, ms.height if ms.height>=0 else 0.2))
if m in STATIC_MODES:
ms.speed = -1.0
elif m in SPEED_RANGES:
lo, hi = SPEED_RANGES[m]
ms.speed = max(lo, min(hi, ms.speed if ms.speed>=0 else lo))
elif m in BOXING_MODES:
ms.speed = max(0.7, min(1.5, ms.speed if ms.speed>=0 else 0.7))
else:
ms.speed = -1.0
def replan_interval(mode):
m = LM(mode)
if m == LM.RUN: return REPLAN_INTERVAL["running"]
if m == LM.CRAWLING: return REPLAN_INTERVAL["crawling"]
if m in {LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}:
return REPLAN_INTERVAL["boxing"]
return REPLAN_INTERVAL["default"]
def _ort_providers(force_cpu: bool = False) -> list[str]:
"""Prefer CUDA for enc/dec/planner (matches deploy when onnxruntime-gpu is installed)."""
avail = ort.get_available_providers()
if not force_cpu and "CUDAExecutionProvider" in avail:
return ["CUDAExecutionProvider", "CPUExecutionProvider"]
return ["CPUExecutionProvider"]
# ── Movement state ────────────────────────────────────────────────────────────
@dataclass
class MovementState:
mode: int = LM.SLOW_WALK # not IDLE — walking modes respond to WASD
speed: float = -1.0
height: float = -1.0
facing_angle: float = 0.0
movement_angle: float = 0.0
has_movement: bool = False
motion_set_idx: int = 0
needs_replan: bool = False
@property
def movement_direction(self):
if not self.has_movement: return (0.0, 0.0, 0.0)
return (math.cos(self.movement_angle), math.sin(self.movement_angle), 0.0)
@property
def facing_direction(self):
return (math.cos(self.facing_angle), math.sin(self.facing_angle), 0.0)
def status_line(self):
return (f"[{MOTION_SETS[self.motion_set_idx][0]}] mode={self.mode}({LM(self.mode).name}) "
f"spd={'default' if self.speed<0 else f'{self.speed:.1f}'} "
f"hgt={'default' if self.height<0 else f'{self.height:.2f}'} "
f"facing={math.degrees(self.facing_angle):.0f}° "
f"{'moving' if self.has_movement else 'still'}")
@dataclass
class MovementSnapshot:
mode: int = 0
speed: float = -1.0
height: float = -1.0
movement: tuple[float, float, float] = (0.0, 0.0, 0.0)
facing: tuple[float, float, float] = (1.0, 0.0, 0.0)
def _snapshot_ms(ms: MovementState) -> MovementSnapshot:
md, fd = ms.movement_direction, ms.facing_direction
return MovementSnapshot(ms.mode, ms.speed, ms.height, (md[0], md[1], md[2]), (fd[0], fd[1], fd[2]))
def should_replan_request(ms: MovementState, last: MovementSnapshot, replan_timer: float, step: int) -> bool:
"""Match C++ G1Deploy::Planner replan triggers (g1_deploy_onnx_ref.cpp)."""
if step <= 0:
return False
if ms.needs_replan:
return True
md, fd = ms.movement_direction, ms.facing_direction
facing_changed = fd != last.facing
height_changed = ms.height != last.height
mode_changed = ms.mode != last.mode
speed_changed = ms.speed != last.speed
dir_changed = md != last.movement
is_static = LM(ms.mode) in STATIC_MODES
if mode_changed or facing_changed or height_changed:
return True
time_to_replan = replan_timer >= replan_interval(ms.mode)
if not is_static and (speed_changed or dir_changed or (time_to_replan and ms.speed != 0)):
return True
return False
# ── Encoder / Decoder ─────────────────────────────────────────────────────────
class StandingEncoderDecoder:
def __init__(self, encoder, decoder):
self.encoder, self.decoder = encoder, decoder
self.encoder_input = encoder.get_inputs()[0].name
self.decoder_input = decoder.get_inputs()[0].name
enc_dim = int(encoder.get_inputs()[0].shape[1])
dec_dim = int(decoder.get_inputs()[0].shape[1])
if enc_dim != 1762 or dec_dim != 994:
raise RuntimeError(f"Unexpected dims encoder={enc_dim}, decoder={dec_dim}")
self.token = np.zeros(TOKEN_DIM, np.float32)
self.last_action_mj = np.zeros(29, np.float32)
self.h_q_mj = [np.zeros(29, np.float32)] * 10
self.h_dq_mj = [np.zeros(29, np.float32)] * 10
self.h_ang = [np.zeros(3, np.float32)] * 10
self.h_act_mj = [np.zeros(29, np.float32)] * 10
self.h_quat = [np.array([1,0,0,0], np.float32)] * 10
self.init_base_quat = np.array([1,0,0,0], np.float32)
self.init_ref_quat = np.array([1,0,0,0], np.float32)
self._heading_init = False
self.encode_mode = 0
self.vr_3point_local_target = VR_TARGET_DEF.copy()
self.vr_3point_local_orn_target = VR_ORN_DEF.copy()
self.smpl_joints_10frame_step1 = SMPL_DEF.copy()
self.set_zero_reference()
def update_history(self, q, dq, ang, quat):
quat = quat / (np.linalg.norm(quat)+1e-8)
q_mj = _to_mujoco(q); dq_mj = _to_mujoco(dq)
self.h_q_mj = [q_mj - DEFAULT_ANGLES_MUJOCO] + self.h_q_mj[:-1]
self.h_dq_mj = [dq_mj] + self.h_dq_mj[:-1]
self.h_ang = [ang.copy()] + self.h_ang[:-1]
self.h_act_mj = [self.last_action_mj.copy()] + self.h_act_mj[:-1]
self.h_quat = [quat.copy()] + self.h_quat[:-1]
if not self._heading_init:
self.init_base_quat = quat.copy(); self._heading_init = True
def _heading_quat(self, q):
h = calc_heading(q) / 2.0
return np.array([np.cos(h), 0, 0, np.sin(h)], np.float32)
def _heading_quat_inv(self, q):
h = calc_heading(q) / 2.0
return np.array([np.cos(-h), 0, 0, np.sin(-h)], np.float32)
def _anchor_6d(self, base_quat, ref_quat=None):
if ref_quat is None: ref_quat = self.init_ref_quat
delta = quat_mul(self._heading_quat(self.init_base_quat), self._heading_quat_inv(self.init_ref_quat))
new_ref = quat_mul(delta, ref_quat)
return quat_to_6d(quat_mul(quat_conj(base_quat), new_ref))
def set_zero_reference(self):
self.motion_joint_positions = [ENCODER_STANDING_REF.copy()]
self.motion_joint_velocities = [np.zeros(29, np.float32)]
self.motion_body_quats = [np.array([1,0,0,0], np.float32)]
self.motion_body_z = [DEFAULT_HEIGHT]
self.motion_timesteps = 1
self.freeze_ref_frame = 0
self.init_ref_quat = self.motion_body_quats[0].copy()
def build_encoder_obs(self):
obs = np.zeros(1762, np.float32)
obs[0] = float(self.encode_mode)
rf = min(self.freeze_ref_frame, self.motion_timesteps - 1)
ref_pos, ref_quat = self.motion_joint_positions[rf], self.motion_body_quats[rf]
if self.encode_mode == 0:
for f in range(10):
obs[4+29*f:4+29*(f+1)] = ref_pos
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
elif self.encode_mode == 1:
ref_lower = ref_pos[LOWER_BODY_IL]
for f in range(10):
obs[661+12*f:661+12*(f+1)] = ref_lower
obs[901:910] = self.vr_3point_local_target
obs[910:922] = self.vr_3point_local_orn_target
obs[595:601] = self._anchor_6d(self.h_quat[0], ref_quat)
elif self.encode_mode == 2:
obs[922:1642] = self.smpl_joints_10frame_step1
for f in range(10):
obs[1642+6*f:1642+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
obs[1702+6*f:1702+6*(f+1)] = ref_pos[WRIST_IL]
else:
raise RuntimeError(f"Unsupported encoder mode: {self.encode_mode}")
return obs
def build_decoder_obs(self):
obs = np.zeros(994, np.float32); off = 0
obs[off:off+64] = self.token; off += 64
for h, sz in [(list(reversed(self.h_ang)),3), (list(reversed(self.h_q_mj)),29),
(list(reversed(self.h_dq_mj)),29), (list(reversed(self.h_act_mj)),29)]:
for f in range(10): obs[off:off+sz] = h[f]; off += sz
for q in reversed(self.h_quat):
obs[off:off+3] = gravity_dir(q); off += 3
assert off == 994, f"Decoder obs mismatch: {off}"
return obs
def run_encoder(self):
return self.encoder.run(None, {self.encoder_input: self.build_encoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
def step(self, robot_obs, update_encoder, debug=False):
jnames = [m.name for m in G1_29_JointIndex]
q = np.array([robot_obs.get(f"{n}.q", DEFAULT_ANGLES[m.value]) for m,n in zip(G1_29_JointIndex,jnames)], np.float32)
dq = np.array([robot_obs.get(f"{n}.dq", 0.0) for n in jnames], np.float32)
quat = np.array([robot_obs.get("imu.quat.w",1), robot_obs.get("imu.quat.x",0),
robot_obs.get("imu.quat.y",0), robot_obs.get("imu.quat.z",0)], np.float32)
ang = np.array([robot_obs.get(f"imu.gyro.{a}",0) for a in "xyz"], np.float32)
self.update_history(q, dq, ang, quat)
if update_encoder: self.token = self.run_encoder()
action_mj = self.decoder.run(None, {self.decoder_input: self.build_decoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
self.last_action_mj = action_mj.copy()
target = DEFAULT_ANGLES + action_mj[ISAACLAB_TO_MUJOCO] * ACTION_SCALE
if debug:
delta = target - q
print(f"token_norm={np.linalg.norm(self.token):.4f} action_norm={np.linalg.norm(action_mj):.4f} "
f"delta_max={np.max(np.abs(delta)):.4f} delta_rms={np.sqrt(np.mean(delta**2)):.4f}")
return {f"{m.name}.q": float(target[m.value]) for m in G1_29_JointIndex}
def print_input_diagnostics(self):
print("\n[Diag] Standing reference checks")
names = {0:"g1", 1:"teleop", 2:"smpl"}
print(f" encoder mode: {self.encode_mode} ({names.get(self.encode_mode,'unknown')})")
print(f" DEFAULT_ANGLES range: [{DEFAULT_ANGLES.min():+.4f}, {DEFAULT_ANGLES.max():+.4f}]")
print(f" anchor_6d(identity): {self._anchor_6d(np.array([1,0,0,0],np.float32), np.array([1,0,0,0],np.float32))}")
print(f" gravity(identity): {gravity_dir(np.array([1,0,0,0],np.float32))} (expect [0,0,-1])")
dec0 = self.build_decoder_obs()
print(f" decoder q-delta max: {np.max(np.abs(dec0[94:384])):.6f}")
print(f" decoder dq max: {np.max(np.abs(dec0[384:674])):.6f}")
# ── Planner motion buffer ─────────────────────────────────────────────────────
class PlannerMotion:
def __init__(self, max_frames=1500):
self.timesteps = 0
self.joint_positions = np.zeros((max_frames, 29), np.float64)
self.joint_velocities = np.zeros((max_frames, 29), np.float64)
self.body_positions = np.zeros((max_frames, 3), np.float64)
self.body_quaternions = np.zeros((max_frames, 4), np.float64)
self.body_quaternions[:, 0] = 1.0
# ── Subprocess planner ────────────────────────────────────────────────────────
def _resample_30_to_50(qpos, n30):
t50 = int(np.floor(n30 / 30.0 * 50))
f30 = np.arange(t50) / 50.0 * 30.0
f0 = np.floor(f30).astype(int)
f1 = np.minimum(f0+1, n30-1)
frac, w0 = (f30-f0).astype(np.float64), None
w0 = 1.0 - frac
jp = (w0[:,None]*qpos[f0,7:36] + frac[:,None]*qpos[f1,7:36])[:,MUJOCO_TO_ISAACLAB]
jv = np.zeros_like(jp)
if t50 >= 2: jv[:t50-1] = (jp[1:] - jp[:-1]) * 50.0; jv[-1] = jv[-2]
return {
"timesteps": t50,
"joint_positions": jp,
"joint_velocities": jv,
"body_positions": w0[:,None]*qpos[f0,:3] + frac[:,None]*qpos[f1,:3],
"body_quaternions": quat_slerp_batch(qpos[f0,3:7], qpos[f1,3:7], frac),
}
def _build_planner_inputs(ctx, ms_dict, version, seed):
inp = {
"context_mujoco_qpos": ctx.astype(np.float32).reshape(1,4,36),
"target_vel": np.array([ms_dict["speed"]], np.float32),
"mode": np.array([ms_dict["mode"]], np.int64),
"movement_direction": np.array(ms_dict["movement_direction"], np.float32).reshape(1,3),
"facing_direction": np.array(ms_dict["facing_direction"], np.float32).reshape(1,3),
"random_seed": np.array([seed], np.int64),
}
if version >= 1:
# TensorRT deploy: allow 911 prediction tokens only (indices 35 for MIN_TOKENS=6).
allowed = np.zeros((1, K), np.int64)
if K >= 6:
allowed[0, 3:6] = 1
inp.update({
"height": np.array([ms_dict["height"]], np.float32),
"has_specific_target": np.array([[0]], np.int64),
"specific_target_positions": np.zeros((1,4,3), np.float32),
"specific_target_headings": np.zeros((1,4), np.float32),
"allowed_pred_num_tokens": allowed,
})
return inp
def _planner_worker(path, req_q, res_q, stop_evt, version, seed, use_gpu):
so = ort.SessionOptions(); so.log_severity_level = 3
providers = _ort_providers(force_cpu=not use_gpu)
sess = ort.InferenceSession(path, sess_options=so, providers=providers)
while not stop_evt.is_set():
try: ctx, gf, ms_dict = req_q.get(timeout=0.05)
except Exception: continue
try:
inp = _build_planner_inputs(ctx, ms_dict, version, seed)
t0 = time.time()
qpos_out, num_pred = sess.run(None, inp)
t_inf = time.time()
n = int(num_pred.flat[0])
qpos = qpos_out[0,:n]
if np.any(np.isnan(qpos)): continue
motion = _resample_30_to_50(qpos, n)
motion["gen_frame"] = gf
print(f"[Planner] inf={1000*(t_inf-t0):.1f}ms total={1000*(time.time()-t0):.1f}ms frames={n}", flush=True)
while not res_q.empty():
try: res_q.get_nowait()
except queue.Empty: break
res_q.put(motion)
except Exception as e:
print(f"[Planner] Error: {e}", flush=True)
# ── SonicPlanner ──────────────────────────────────────────────────────────────
class SonicPlanner:
def __init__(self, session, planner_path):
self.session = session
self.planner_path = planner_path
self.gen_frame = 0
self.random_seed = INITIAL_RANDOM_SEED
self.version = 1 if len(session.get_inputs()) >= 11 else 0
self.motion_50hz = PlannerMotion()
self._snapshot = PlannerMotion()
self._req_q = self._res_q = self._stop_evt = self._planner_thread = None
self._ctrl = None
def _build_inputs(self, ctx, ms):
return _build_planner_inputs(
ctx,
{"mode": ms.mode, "speed": ms.speed, "height": ms.height,
"movement_direction": list(ms.movement_direction),
"facing_direction": list(ms.facing_direction)},
self.version, self.random_seed)
@staticmethod
def build_initial_context(joint_positions):
ctx = np.zeros((4, 36), np.float32)
jp_mj = joint_positions.astype(np.float32)[ISAACLAB_TO_MUJOCO]
for n in range(4):
ctx[n, 2] = DEFAULT_HEIGHT
ctx[n, 3] = 1.0
ctx[n, 7:36] = jp_mj
return ctx
def _context_from_controller(self, current_frame):
ctrl = self._ctrl
gen_frame = current_frame + MOTION_LOOK_AHEAD_STEPS
t_arr = gen_frame / 50.0 + np.arange(4) / 30.0
f50 = t_arr * 50.0
with ctrl.motion_lock:
ts = ctrl.motion_timesteps
if ts <= 0:
return self.build_initial_context(DEFAULT_ANGLES)
bp, bq, jp = ctrl.motion_body_pos, ctrl.motion_body_quats, ctrl.motion_joint_positions
f0 = np.minimum(np.floor(f50).astype(int), ts - 1)
f1 = np.minimum(f0 + 1, ts - 1)
frac = f50 - f0
w0 = 1.0 - frac
ctx = np.zeros((4, 36), np.float32)
ctx[:, 0:3] = w0[:, None] * bp[f0] + frac[:, None] * bp[f1]
ctx[:, 3:7] = quat_slerp_batch(bq[f0], bq[f1], frac)
ij = w0[:, None] * jp[f0] + frac[:, None] * jp[f1]
ctx[:, 7:36] = ij[:, ISAACLAB_TO_MUJOCO]
self.gen_frame = gen_frame
return ctx
def _load_motion_in_place(self, qpos, n30, target=None):
if target is None: target = self.motion_50hz
r = _resample_30_to_50(qpos, n30)
n = r["timesteps"]; target.timesteps = n
target.joint_positions[:n] = r["joint_positions"]
target.joint_velocities[:n] = r["joint_velocities"]
target.body_positions[:n] = r["body_positions"]
target.body_quaternions[:n] = r["body_quaternions"]
return target
def initialize(self, joint_positions, ms):
ctx = self.build_initial_context(joint_positions)
qpos_out, num_pred = self.session.run(None, self._build_inputs(ctx, ms))
n = int(num_pred.flat[0]); qpos = qpos_out[0,:n]
if np.any(np.isnan(qpos)): raise RuntimeError("Planner initial output contains NaN")
print(f"[Planner] Init: {n} frames @ 30 Hz")
self._load_motion_in_place(qpos, n)
print(f"[Planner] Resampled to {self.motion_50hz.timesteps} frames @ 50 Hz")
return self.motion_50hz
def request_replan(self, cursor, ms):
if self._req_q is None: return
ctx = self._context_from_controller(cursor)
ms_dict = {"mode": ms.mode, "speed": ms.speed, "height": ms.height,
"movement_direction": list(ms.movement_direction),
"facing_direction": list(ms.facing_direction)}
while not self._req_q.empty():
try: self._req_q.get_nowait()
except queue.Empty: break
self._req_q.put((ctx, self.gen_frame, ms_dict))
def try_get_new_motion(self):
if self._res_q is None: return None
result = None
while not self._res_q.empty():
try: result = self._res_q.get_nowait()
except queue.Empty: break
if result is None: return None
n, gf = result["timesteps"], result["gen_frame"]
s = self._snapshot; s.timesteps = n
s.joint_positions[:n] = result["joint_positions"]
s.joint_velocities[:n] = result["joint_velocities"]
s.body_positions[:n] = result["body_positions"]
s.body_quaternions[:n] = result["body_quaternions"]
return s, gf
def start_subprocess(self, controller, use_gpu: bool = False):
"""Run planner ONNX in a background thread (avoids mp spawn/fork + CUDA/MuJoCo issues)."""
self._ctrl = controller
self._req_q = queue.Queue()
self._res_q = queue.Queue()
self._stop_evt = threading.Event()
self._planner_thread = threading.Thread(
target=_planner_worker,
args=(self.planner_path, self._req_q, self._res_q,
self._stop_evt, self.version, self.random_seed, use_gpu),
daemon=True,
name="sonic-planner",
)
self._planner_thread.start()
print(f"[Planner] Background thread started ({'GPU' if use_gpu else 'CPU'})")
def stop_subprocess(self):
if self._stop_evt:
self._stop_evt.set()
if self._planner_thread is not None:
self._planner_thread.join(timeout=3.0)
print("[Planner] Background thread stopped")
self._planner_thread = None
self._req_q = self._res_q = self._stop_evt = None
# ── PlannerController ─────────────────────────────────────────────────────────
class PlannerController(StandingEncoderDecoder):
def __init__(self, planner, encoder, decoder):
super().__init__(encoder, decoder)
self.planner = planner
self.ref_cursor = 0
self.motion_timesteps = 0
self.motion_joint_positions = np.zeros((1500,29), np.float64)
self.motion_joint_velocities = np.zeros((1500,29), np.float64)
self.motion_body_quats = np.zeros((1500,4), np.float64); self.motion_body_quats[:,0] = 1.0
self.motion_body_pos = np.zeros((1500,3), np.float64)
self.init_ref_quat = np.array([1,0,0,0], np.float64)
self.heading_init_base_quat = np.array([1,0,0,0], np.float64)
self.delta_heading = 0.0
self.reinit_heading = False
self.playing = self.first_motion = False
self.motion_lock = threading.Lock()
def load_initial_motion(self, motion):
with self.motion_lock:
n = motion.timesteps
self.motion_timesteps = n
self.motion_joint_positions[:n] = motion.joint_positions[:n]
self.motion_joint_velocities[:n] = motion.joint_velocities[:n]
self.motion_body_quats[:n] = motion.body_quaternions[:n]
self.motion_body_pos[:n] = motion.body_positions[:n]
self.init_ref_quat = motion.body_quaternions[0].copy()
self.ref_cursor = 0; self.first_motion = True
self.playing = True; self.delta_heading = 0.0
def blend_new_motion(self, new_motion, gen_frame):
"""Blend like C++ CurrentFrameAdvancement: 8-frame cross-fade, then copy tail."""
with self.motion_lock:
cur = self.ref_cursor
new_len = gen_frame - cur + new_motion.timesteps
if new_len <= 0:
return
if self.motion_timesteps == 0:
n = new_motion.timesteps
self.motion_joint_positions[:n] = new_motion.joint_positions[:n]
self.motion_joint_velocities[:n] = new_motion.joint_velocities[:n]
self.motion_body_pos[:n] = new_motion.body_positions[:n]
self.motion_body_quats[:n] = new_motion.body_quaternions[:n]
self.motion_timesteps = n
self.ref_cursor = 0
self.init_ref_quat = self.motion_body_quats[0].copy()
self.first_motion = False
return
blend_start = max(0, gen_frame - cur)
blend_end = min(new_len, blend_start + BLEND_FRAMES)
for f in range(blend_end):
f_old = min(f + cur, self.motion_timesteps - 1)
f_new = max(0, min(f + cur - gen_frame, new_motion.timesteps - 1))
w_new = min(1.0, max(0.0, (f - blend_start) / BLEND_FRAMES))
w_old = 1.0 - w_new
self.motion_joint_positions[f] = (
w_old * self.motion_joint_positions[f_old]
+ w_new * new_motion.joint_positions[f_new]
)
self.motion_joint_velocities[f] = (
w_old * self.motion_joint_velocities[f_old]
+ w_new * new_motion.joint_velocities[f_new]
)
self.motion_body_pos[f] = (
w_old * self.motion_body_pos[f_old]
+ w_new * new_motion.body_positions[f_new]
)
self.motion_body_quats[f] = quat_slerp(
self.motion_body_quats[f_old], new_motion.body_quaternions[f_new], w_new
)
for f in range(blend_end, new_len):
f_new = max(0, min(f + cur - gen_frame, new_motion.timesteps - 1))
self.motion_joint_positions[f] = new_motion.joint_positions[f_new]
self.motion_joint_velocities[f] = new_motion.joint_velocities[f_new]
self.motion_body_pos[f] = new_motion.body_positions[f_new]
self.motion_body_quats[f] = new_motion.body_quaternions[f_new].copy()
self.motion_timesteps = new_len
self.first_motion = False
self.ref_cursor = 0
self.init_ref_quat = self.motion_body_quats[0].copy()
def _heading_apply_delta(self):
delta = quat_mul(heading_quat(self.heading_init_base_quat).astype(np.float32),
heading_quat_inv(self.init_ref_quat).astype(np.float32))
if self.delta_heading:
h = self.delta_heading / 2.0
delta = quat_mul(np.array([np.cos(h),0,0,np.sin(h)], np.float32), delta)
return delta
def _anchor_6d(self, base_quat, ref_quat=None):
if ref_quat is None: ref_quat = self.init_ref_quat
new_ref = quat_mul(self._heading_apply_delta(), ref_quat.astype(np.float32))
return quat_to_6d(quat_mul(quat_conj(base_quat.astype(np.float32)), new_ref))
def build_encoder_obs(self):
obs = np.zeros(1762, np.float32); obs[0] = float(self.encode_mode)
with self.motion_lock:
for f in range(10):
tf = min(self.ref_cursor + f*5 if self.playing else self.ref_cursor,
self.motion_timesteps - 1)
obs[4+29*f:4+29*(f+1)] = self.motion_joint_positions[tf].astype(np.float32)
if self.playing:
obs[294+29*f:294+29*(f+1)] = self.motion_joint_velocities[tf].astype(np.float32)
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(
self.h_quat[0], self.motion_body_quats[tf].astype(np.float32))
return obs
def step(self, robot_obs, update_encoder, debug=False):
if robot_obs and (self.first_motion or self.reinit_heading):
q = None
if "imu.quat.w" in robot_obs:
q = np.array([
robot_obs["imu.quat.w"], robot_obs["imu.quat.x"],
robot_obs["imu.quat.y"], robot_obs["imu.quat.z"],
], np.float64)
else:
q = robot_obs.get("imu.quaternion")
if q is not None:
q = np.array(q, np.float64)
if q is not None:
self.heading_init_base_quat = np.array(q, np.float64)
with self.motion_lock:
rf = min(self.ref_cursor, self.motion_timesteps - 1)
self.init_ref_quat = self.motion_body_quats[rf].copy()
self.delta_heading = 0.0
self.first_motion = False
self.reinit_heading = False
print(f"[Heading] init quat: {self.heading_init_base_quat}")
return super().step(robot_obs, update_encoder=update_encoder, debug=debug)
def advance_cursor(self):
"""Advance one frame per 50 Hz tick (C++ current_frame_ += 1), no wall-clock catch-up."""
if not self.playing:
return
with self.motion_lock:
if self.motion_timesteps > 0:
self.ref_cursor = min(self.ref_cursor + 1, self.motion_timesteps - 1)
# ── Keyboard ──────────────────────────────────────────────────────────────────
class RawKeyboard:
def __init__(self):
self.fd = sys.stdin.fileno()
self.old = termios.tcgetattr(self.fd)
def __enter__(self): tty.setcbreak(self.fd); return self
def __exit__(self, *_): termios.tcsetattr(self.fd, termios.TCSADRAIN, self.old)
def get_key(self):
return sys.stdin.read(1) if select.select([sys.stdin],[],[],0)[0] else None
def drain_keyboard(kb, ms, controller=None) -> bool:
"""Process all pending terminal keys this frame (return True to quit)."""
quit_requested = False
while True:
key = kb.get_key()
if key is None:
break
if process_keyboard(key, ms, controller):
quit_requested = True
return quit_requested
def process_keyboard(key, ms, controller=None):
if key is None: return False
if key == '\x1b': return True
if key == ' ':
ms.mode = LM.IDLE; ms.speed = ms.height = -1.0
ms.has_movement = False; ms.needs_replan = True
if controller: controller.playing = False; controller.reinit_heading = True
print("\n >> EMERGENCY STOP -> IDLE"); return False
if key in ('r','R'):
ms.needs_replan = True; print("\n >> Manual replan"); return False
if key in ('n','N','p','P'):
ms.motion_set_idx = (ms.motion_set_idx + (1 if key in ('n','N') else -1)) % len(MOTION_SETS)
name, modes = MOTION_SETS[ms.motion_set_idx]
print(f"\n >> Motion set: {name}")
[print(f" {i+1}: {m.name}") for i,m in enumerate(modes)]
return False
if key.isdigit() and key not in ('9','0'):
idx = int(key) - 1; modes = MOTION_SETS[ms.motion_set_idx][1]
if 0 <= idx < len(modes):
ms.mode = modes[idx]; ms.needs_replan = True
if controller: controller.playing = True; controller.reinit_heading = True
print(f"\n >> Mode: {LM(ms.mode).name} ({ms.mode}) [replanning...]")
return False
if key == '9':
ms.speed = max(0.0, (ms.speed if ms.speed>=0 else 1.0) - 0.1)
print(f"\n >> Speed: {ms.speed:.1f}"); return False
if key == '0':
ms.speed = min(5.0, (ms.speed if ms.speed>=0 else 1.0) + 0.1)
print(f"\n >> Speed: {ms.speed:.1f}"); return False
if key == '-':
ms.height = max(0.2, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) - 0.02)
print(f"\n >> Height: {ms.height:.2f}"); return False
if key == '=':
ms.height = min(1.0, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) + 0.02)
print(f"\n >> Height: {ms.height:.2f}"); return False
if key.lower() == 'w': ms.movement_angle = ms.facing_angle
elif key.lower() == 's': ms.movement_angle = ms.facing_angle + math.pi
elif key.lower() == 'a': ms.movement_angle = ms.facing_angle + math.pi/2
elif key.lower() == 'd': ms.movement_angle = ms.facing_angle - math.pi/2
if key.lower() in ('w','s','a','d'):
ms.has_movement = ms.needs_replan = True
if controller:
controller.playing = True
print(f"\n >> Move {key.upper()} (replanning...)")
elif key.lower() == 'q':
ms.facing_angle += 0.1
if controller: controller.delta_heading += 0.1
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
elif key.lower() == 'e':
ms.facing_angle -= 0.1
if controller: controller.delta_heading -= 0.1
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
return False
_joy_prev_active = False
def _parse_wireless(wr):
"""Parse wireless_remote (bytes or int-array) into (lx, ly, rx, ry)."""
import struct as _st
if not isinstance(wr, (bytes, bytearray)):
wr = bytes(wr)
if len(wr) < 24:
return None
lx = _st.unpack("f", wr[4:8])[0]
rx = _st.unpack("f", wr[8:12])[0]
ry = _st.unpack("f", wr[12:16])[0]
ly = _st.unpack("f", wr[20:24])[0]
return lx, ly, rx, ry
def process_joystick(obs, ms, controller=None):
"""Joystick mirrors keyboard: left stick=WASD, right stick X=Q/E, right stick Y=height."""
global _joy_prev_active
wr = obs.get("wireless_remote")
if wr is None:
return
parsed = _parse_wireless(wr)
if parsed is None:
return
lx, ly, rx, ry = parsed
# Dead zone + negate both Y axes (bridge already flips them once)
lx = 0.0 if abs(lx) < DEADZONE else lx
ly = 0.0 if abs(ly) < DEADZONE else -ly
rx = 0.0 if abs(rx) < DEADZONE else rx
ry = 0.0 if abs(ry) < DEADZONE else -ry
left_active = abs(lx) > 0 or abs(ly) > 0
# Left stick → WASD (movement direction relative to facing)
if left_active:
ms.movement_angle = ms.facing_angle + math.atan2(-lx, -ly)
ms.has_movement = True
if not _joy_prev_active:
ms.needs_replan = True
_joy_prev_active = True
elif _joy_prev_active and not (abs(rx) > 0 or abs(ry) > 0):
_joy_prev_active = False
ms.has_movement = False
# Right stick X → Q/E (facing rotation, ~1 rad/s at full deflection)
if abs(rx) > 0:
delta = -0.02 * rx
ms.facing_angle += delta
if controller:
controller.delta_heading += delta
# Right stick Y → -/= (height adjustment, ~0.25/s at full deflection)
if abs(ry) > 0:
step = -0.005 * ry
ms.height = max(0.1, min(1.0, (ms.height if ms.height >= 0 else DEFAULT_HEIGHT) + step))
@@ -0,0 +1,152 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SONIC full-body controller for Unitree G1."""
import logging
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.controllers.sonic_pipeline import (
CONTROL_DT,
DEBUG_PRINT_EVERY,
DEFAULT_ANGLES,
ENCODER_UPDATE_EVERY,
LM,
MOTION_SETS,
MovementState,
PlannerController,
SonicPlanner,
clamp_mode_params,
compute_kp_kd,
lowstate_to_obs,
process_joystick,
should_replan_request,
_ort_providers,
_snapshot_ms,
)
logger = logging.getLogger(__name__)
class SonicRuntime:
"""Shared SONIC control loop state (standalone demo + locomotion controller)."""
def __init__(self, force_cpu: bool = False):
planner_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="planner_sonic.onnx")
encoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_encoder.onnx")
decoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_decoder.onnx")
providers = _ort_providers(force_cpu=force_cpu)
self.use_gpu = providers[0] == "CUDAExecutionProvider"
so = ort.SessionOptions()
so.log_severity_level = 3
planner_sess = ort.InferenceSession(planner_path, sess_options=so, providers=providers)
encoder_sess = ort.InferenceSession(encoder_path, sess_options=so, providers=providers)
decoder_sess = ort.InferenceSession(decoder_path, sess_options=so, providers=providers)
self.kp, self.kd = compute_kp_kd()
self.ms = MovementState()
self.planner = SonicPlanner(planner_sess, planner_path)
self.controller = PlannerController(self.planner, encoder_sess, decoder_sess)
motion = self.planner.initialize(DEFAULT_ANGLES, self.ms)
self.controller.load_initial_motion(motion)
self.planner.start_subprocess(self.controller, use_gpu=self.use_gpu)
self.step = 0
self.replan_timer = 0.0
self.last_ms = _snapshot_ms(self.ms)
@property
def pipeline(self):
return self.controller
def tick(self, obs: dict, *, debug: bool | None = None, use_joystick: bool = True) -> dict:
if not obs:
self.step += 1
return {}
if use_joystick:
process_joystick(obs, self.ms, self.controller)
clamp_mode_params(self.ms)
if self.step > 0:
self.replan_timer += CONTROL_DT
if should_replan_request(self.ms, self.last_ms, self.replan_timer, self.step):
self.planner.request_replan(self.controller.ref_cursor, self.ms)
self.replan_timer = 0.0
self.ms.needs_replan = False
self.last_ms = _snapshot_ms(self.ms)
do_enc = self.step % ENCODER_UPDATE_EVERY == 0
if debug is None:
debug = self.step % DEBUG_PRINT_EVERY == 0
action = self.controller.step(obs, update_encoder=do_enc, debug=debug)
result = self.planner.try_get_new_motion()
if result:
self.controller.blend_new_motion(*result)
self.controller.advance_cursor()
self.step += 1
return action
def reset(self):
self.ms = MovementState()
self.controller.reinit_heading = True
self.controller.playing = True
self.step = 0
self.replan_timer = 0.0
self.last_ms = _snapshot_ms(self.ms)
def shutdown(self):
self.planner.stop_subprocess()
class SonicWholeBodyController:
"""Full-body SONIC controller for UnitreeG1's background controller thread."""
control_dt = CONTROL_DT
full_body = True
def __init__(self, force_cpu: bool = False):
logger.info("Loading SONIC whole-body controller...")
self._runtime = SonicRuntime(force_cpu=force_cpu)
self.kp = self._runtime.kp
self.kd = self._runtime.kd
self.controller = self._runtime.controller
self.ms = self._runtime.ms
logger.info(
"SONIC ready: %s (default mode: %s)",
MOTION_SETS[0][0],
LM(self.ms.mode).name,
)
def run_step(self, action: dict, lowstate) -> dict:
if lowstate is None:
return {}
obs = lowstate_to_obs(lowstate)
return self._runtime.tick(obs, debug=False)
def reset(self):
self._runtime.reset()
def shutdown(self):
self._runtime.shutdown()
+3 -2
View File
@@ -68,8 +68,9 @@ def make_locomotion_controller(name: str | None):
if name is None:
return None
controllers = {
"GrootLocomotionController": "lerobot.robots.unitree_g1.gr00t_locomotion",
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.holosoma_locomotion",
"GrootLocomotionController": "lerobot.robots.unitree_g1.controllers.gr00t_locomotion",
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.controllers.holosoma_locomotion",
"SonicWholeBodyController": "lerobot.robots.unitree_g1.controllers.sonic_whole_body",
}
module_path = controllers.get(name)
if module_path is None:
+9 -1
View File
@@ -338,6 +338,9 @@ class UnitreeG1(Robot):
self.kp = np.array(self.config.kp, dtype=np.float32)
self.kd = np.array(self.config.kd, dtype=np.float32)
if self.controller is not None and hasattr(self.controller, "kp"):
self.kp = np.array(self.controller.kp, dtype=np.float32)
self.kd = np.array(self.controller.kd, dtype=np.float32)
for joint in G1_29_JointIndex:
self.msg.motor_cmd[joint].mode = 1
@@ -374,6 +377,9 @@ class UnitreeG1(Robot):
# Signal thread to stop and unblock any waits
self._shutdown_event.set()
if self.controller is not None and hasattr(self.controller, "shutdown"):
self.controller.shutdown()
# Wait for subscribe thread to finish
if self.subscribe_thread is not None:
self.subscribe_thread.join(timeout=2.0)
@@ -465,9 +471,11 @@ class UnitreeG1(Robot):
def send_action(self, action: RobotAction) -> RobotAction:
action_to_publish = action
if self.controller is not None:
self._update_controller_action(action)
if getattr(self.controller, "full_body", False):
return action
# Controller thread owns legs/waist. Here we only update joystick inputs
# and publish arm targets from the teleoperator.
self._update_controller_action(action)
arm_prefixes = tuple(j.name for j in G1_29_JointArmIndex)
action_to_publish = {
key: value
+1
View File
@@ -54,6 +54,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
@@ -57,6 +57,7 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
+1
View File
@@ -137,6 +137,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
+1
View File
@@ -174,6 +174,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
@@ -41,6 +41,7 @@ from lerobot.robots import ( # noqa: F401
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
koch_leader,
@@ -89,6 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
+6 -56
View File
@@ -45,8 +45,7 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state
from lerobot.datasets.factory import make_train_eval_datasets
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -245,19 +244,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
if is_main_process:
logging.info("Creating dataset")
dataset, eval_dataset = make_train_eval_datasets(cfg)
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Other ranks read from the shared copy populated by the main process.
if not is_main_process:
dataset, eval_dataset = make_train_eval_datasets(cfg)
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.env_eval_freq > 0 and cfg.env is not None and is_main_process:
if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
@@ -456,33 +455,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
# Build eval dataloader if a held-out split exists
eval_dataloader = None
if eval_dataset is not None:
eval_ds = eval_dataset
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy()
unique_tasks = sorted(set(task_arr.tolist()))
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
selected: list[int] = []
for t in unique_tasks:
frames = (task_arr == t).nonzero()[0][:per_task]
selected.extend(frames.tolist())
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
eval_dataloader = torch.utils.data.DataLoader(
eval_ds,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=eval_collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
# Prepare everything with accelerator
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
@@ -562,8 +534,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_env_eval_step = cfg.env_eval_freq > 0 and step % cfg.env_eval_freq == 0
is_eval_step = cfg.eval_steps > 0 and eval_dataloader is not None and step % cfg.eval_steps == 0
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step:
# Collective reduce must run on every rank, before the main-process gate below.
@@ -586,27 +557,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if is_eval_step:
policy.eval()
eval_loss_sum = 0.0
n_eval_batches = 0
with torch.no_grad(), accelerator.autocast():
for eval_batch in eval_dataloader:
for cam_key in dataset.meta.camera_keys:
if cam_key in eval_batch and eval_batch[cam_key].dtype == torch.uint8:
eval_batch[cam_key] = eval_batch[cam_key].to(dtype=torch.float32) / 255.0
eval_batch = preprocessor(eval_batch)
loss, _ = policy.forward(eval_batch)
eval_loss_sum += loss.item()
n_eval_batches += 1
eval_loss = eval_loss_sum / max(n_eval_batches, 1)
policy.train()
if is_main_process:
logging.info(f"step {step}: eval_loss={eval_loss:.4f}")
if wandb_logger:
wandb_logger.log_dict({"eval_loss": eval_loss}, step=step, mode="eval")
if cfg.save_checkpoint and is_saving_step:
if is_main_process:
logging.info(f"Checkpoint policy after step {step}")
@@ -629,7 +579,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
accelerator.wait_for_everyone()
if cfg.env and is_env_eval_step:
if cfg.env and is_eval_step:
if is_main_process:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
@@ -18,7 +18,8 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
from ..teleoperator import Teleoperator
@@ -27,7 +28,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
logger = logging.getLogger(__name__)
class BiOpenArmLeader(Teleoperator):
class BiOpenArmLeader(BimanualMixin, Teleoperator):
"""
Bimanual OpenArm Leader Arms
"""
@@ -86,27 +87,6 @@ class BiOpenArmLeader(Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -129,8 +109,3 @@ class BiOpenArmLeader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
@dataclass
class BiOpenArmLeaderConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Follower robots."""
"""Configuration class for Bi OpenArm Leader teleoperators."""
left_arm_config: OpenArmLeaderConfigBase
right_arm_config: OpenArmLeaderConfigBase
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_openarm_mini import BiOpenArmMini
from .config_bi_openarm_mini import BiOpenArmMiniConfig
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
@@ -0,0 +1,101 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
from ..teleoperator import Teleoperator
from .config_bi_openarm_mini import BiOpenArmMiniConfig
logger = logging.getLogger(__name__)
class BiOpenArmMini(BimanualMixin, Teleoperator):
"""Bimanual OpenArm Mini teleoperator.
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
bimanual leader can teleoperate a bimanual OpenArm follower.
"""
config_class = BiOpenArmMiniConfig
name = "bi_openarm_mini"
def __init__(self, config: BiOpenArmMiniConfig):
super().__init__(config)
self.config = config
# `side` is forced to match left/right regardless of what the user passed
# on the per-arm base config — the bimanual wrapper owns the side semantics.
left_arm_config = OpenArmMiniConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
side="left",
use_degrees=config.left_arm_config.use_degrees,
)
right_arm_config = OpenArmMiniConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
side="right",
use_degrees=config.right_arm_config.use_degrees,
)
self.left_arm = OpenArmMini(left_arm_config)
self.right_arm = OpenArmMini(right_arm_config)
@cached_property
def action_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
}
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
action: RobotAction = {}
for k, v in self.left_arm.get_action().items():
action[f"left_{k}"] = v
for k, v in self.right_arm.get_action().items():
action[f"right_{k}"] = v
return action
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
if left_fb:
self.left_arm.send_feedback(left_fb)
if right_fb:
self.right_arm.send_feedback(right_fb)
@@ -0,0 +1,29 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
from ..openarm_mini import OpenArmMiniConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
@dataclass
class BiOpenArmMiniConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Mini teleoperators."""
left_arm_config: OpenArmMiniConfigBase
right_arm_config: OpenArmMiniConfigBase
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_rebot_102_leader import BiRebotArm102Leader
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
from .bi_rebot_102_leader import BiRebot102Leader
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]
__all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
@@ -18,16 +18,17 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
from ..teleoperator import Teleoperator
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
logger = logging.getLogger(__name__)
class BiRebotArm102Leader(Teleoperator):
class BiRebot102Leader(BimanualMixin, Teleoperator):
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
@@ -35,10 +36,10 @@ class BiRebotArm102Leader(Teleoperator):
leader can teleoperate a bimanual reBot B601 follower.
"""
config_class = BiRebotArm102LeaderConfig
config_class = BiRebot102LeaderConfig
name = "bi_rebot_102_leader"
def __init__(self, config: BiRebotArm102LeaderConfig):
def __init__(self, config: BiRebot102LeaderConfig):
super().__init__(config)
self.config = config
@@ -76,27 +77,6 @@ class BiRebotArm102Leader(Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_action(self) -> RobotAction:
action_dict = {}
@@ -106,8 +86,3 @@ class BiRebotArm102Leader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
@dataclass
class BiRebotArm102LeaderConfig(TeleoperatorConfig):
class BiRebot102LeaderConfig(TeleoperatorConfig):
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
left_arm_config: RebotArm102LeaderConfig
@@ -17,7 +17,9 @@
import logging
from functools import cached_property
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..so_leader import SOLeader, SOLeaderTeleopConfig
from ..teleoperator import Teleoperator
@@ -26,7 +28,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
logger = logging.getLogger(__name__)
class BiSOLeader(Teleoperator):
class BiSOLeader(BimanualMixin, Teleoperator):
"""
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
@@ -67,33 +69,12 @@ class BiSOLeader(Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> dict[str, float]:
def get_action(self) -> RobotAction:
action_dict = {}
# Add "left_" prefix
@@ -109,8 +90,3 @@ class BiSOLeader(Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarm_mini import OpenArmMiniConfig
from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
from .openarm_mini import OpenArmMini
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
__all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
@@ -19,12 +19,21 @@ from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig):
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
class OpenArmMiniConfigBase:
"""Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
port_right: str = "/dev/ttyUSB0"
port_left: str = "/dev/ttyUSB1"
# Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
port: str
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
# during readout. If `None`, no flipping is applied.
side: str | None = None
use_degrees: bool = True
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
pass
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__)
# Motors whose direction is inverted during readout
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
# Per-side motor direction flips applied during readout.
SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
"left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
}
# Leader joint 6 maps to follower joint 7 and vice versa
# Leader joint 6 follower joint 7 (symmetric — its own inverse).
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
GRIPPER_TELEOP_TO_DEGREES = -0.65
class OpenArmMini(Teleoperator):
"""
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
"""OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
"""
config_class = OpenArmMiniConfig
@@ -56,9 +56,12 @@ class OpenArmMini(Teleoperator):
super().__init__(config)
self.config = config
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
norm_mode_body = MotorNormMode.DEGREES
motors_right = {
motors = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
@@ -69,46 +72,15 @@ class OpenArmMini(Teleoperator):
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
motors_left = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
self.bus = FeetechMotorsBus(
port=self.config.port,
motors=motors,
calibration=self.calibration,
)
@property
def action_features(self) -> dict[str, type]:
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
return {f"{motor}.pos": float for motor in self.bus.motors}
@property
def feedback_features(self) -> dict[str, type]:
@@ -116,14 +88,12 @@ class OpenArmMini(Teleoperator):
@property
def is_connected(self) -> bool:
return self.bus_right.is_connected and self.bus_left.is_connected
return self.bus.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
logger.info(f"Connecting arm on {self.config.port}...")
self.bus.connect()
if calibrate:
self.calibrate()
@@ -133,14 +103,14 @@ class OpenArmMini(Teleoperator):
@property
def is_calibrated(self) -> bool:
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
return self.bus.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for OpenArm Mini.
Run calibration procedure for a single OpenArm Mini arm.
1. Disable torque
2. Ask user to position arms in hanging position with grippers closed
2. Ask user to position arm in hanging position with gripper closed
3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions)
5. Save calibration
@@ -152,70 +122,51 @@ class OpenArmMini(Teleoperator):
)
if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}")
cal_right = {
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
self.bus.write_calibration(self.calibration)
return
logger.info(f"\nRunning calibration for {self}")
self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
self.bus.disable_torque()
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
logger.info("Setting Phase to 12 for all motors...")
for motor in self.bus.motors:
self.bus.write("Phase", motor, 12)
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
"""Calibrate a single arm with Feetech motors."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input(
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"\nCalibration: Zero Position\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
homing_offsets = bus.set_half_turn_homings()
logger.info(f"{arm_name.capitalize()} arm zero position set.")
homing_offsets = self.bus.set_half_turn_homings()
logger.info("Arm zero position set.")
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
print("\nSetting motor ranges\n")
if self.calibration is None:
self.calibration = {}
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
max_res = motor_resolution - 1
for motor_name, motor in bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
for motor_name, motor in self.bus.motors.items():
if motor_name == "gripper":
input(
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
f"Step 1: CLOSE the gripper fully\n"
f"Press ENTER when gripper is closed..."
"\nGripper Calibration\n"
"Step 1: CLOSE the gripper fully\n"
"Press ENTER when gripper is closed..."
)
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = bus.read("Present_Position", motor_name, normalize=False)
open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos:
@@ -228,16 +179,16 @@ class OpenArmMini(Teleoperator):
drive_mode = 1
logger.info(
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
f" {motor_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})"
)
else:
range_min = 0
range_max = max_res
drive_mode = 0
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[prefixed_name] = MotorCalibration(
self.calibration[motor_name] = MotorCalibration(
id=motor.id,
drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name],
@@ -245,108 +196,68 @@ class OpenArmMini(Teleoperator):
range_max=range_max,
)
cal_for_bus = {
k.replace(f"{arm_name}_", ""): v
for k, v in self.calibration.items()
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
self.bus.write_calibration(self.calibration)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
def configure(self) -> None:
self.bus_right.disable_torque()
self.bus_right.configure_motors()
for motor in self.bus_right.motors:
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus.disable_torque()
self.bus.configure_motors()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None:
print("\nSetting up RIGHT arm motors...")
for motor in reversed(self.bus_right.motors):
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
self.bus_right.setup_motor(motor)
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> RobotAction:
"""Get current action from both arms (read positions from all motors)."""
"""Get current action (read positions from all motors)."""
start = time.perf_counter()
right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
positions = self.bus.sync_read("Present_Position")
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
# Per-side direction flip is applied based on the configured `side`.
action: dict[str, Any] = {}
for motor, val in right_positions.items():
for motor, val in positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def enable_torque(self) -> None:
"""Enable torque on both arms for position control."""
self.bus_right.enable_torque()
self.bus_left.enable_torque()
self.bus.enable_torque()
def disable_torque(self) -> None:
"""Disable torque on both arms for free movement."""
self.bus_right.disable_torque()
self.bus_left.disable_torque()
self.bus.disable_torque()
def write_goal_positions(self, positions: dict[str, float]) -> None:
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
right_goals: dict[str, float] = {}
left_goals: dict[str, float] = {}
goals: dict[str, float] = {}
for key, val in positions.items():
if not key.endswith(".pos"):
continue
motor_name = key.removesuffix(".pos")
if motor_name.startswith("right_"):
base = motor_name.removeprefix("right_")
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
elif motor_name.startswith("left_"):
base = motor_name.removeprefix("left_")
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
base = key.removesuffix(".pos")
# JOINT_REMAP is symmetric (its own inverse).
target = JOINT_REMAP.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
goals[target] = -val if target in self._motors_to_flip else val
if right_goals:
self.bus_right.sync_write("Goal_Position", right_goals)
if left_goals:
self.bus_left.sync_write("Goal_Position", left_goals)
if goals:
self.bus.sync_write("Goal_Position", goals)
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
@@ -354,6 +265,5 @@ class OpenArmMini(Teleoperator):
@check_if_not_connected
def disconnect(self) -> None:
self.bus_right.disconnect()
self.bus_left.disconnect()
self.bus.disconnect()
logger.info(f"{self} disconnected.")
+6 -2
View File
@@ -99,14 +99,18 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .openarm_mini import OpenArmMini
return OpenArmMini(config)
elif config.type == "bi_openarm_mini":
from .bi_openarm_mini import BiOpenArmMini
return BiOpenArmMini(config)
elif config.type == "rebot_102_leader":
from .rebot_102_leader import RebotArm102Leader
return RebotArm102Leader(config)
elif config.type == "bi_rebot_102_leader":
from .bi_rebot_102_leader import BiRebotArm102Leader
from .bi_rebot_102_leader import BiRebot102Leader
return BiRebotArm102Leader(config)
return BiRebot102Leader(config)
else:
try:
return cast("Teleoperator", make_device_from_device_class(config))
+63
View File
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
class BimanualMixin:
"""Lifecycle delegation for bimanual robots and teleoperators.
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
their own ``__init__``. They retain ownership of feature dicts and the
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
``send_feedback``), which vary per-embodiment.
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
take precedence in the MRO::
class BiFooFollower(BimanualMixin, Robot): ...
"""
left_arm: Any
right_arm: Any
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
+73
View File
@@ -28,6 +28,7 @@ import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import pyarrow.parquet as pq # noqa: E402
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
@@ -344,6 +345,78 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
assert len(dataset) == 24
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
from lerobot.datasets.io_utils import write_tasks
from lerobot.utils.io_utils import write_json
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
for ep, length in enumerate(episode_lengths):
episode_index += [ep] * length
frame_index += list(range(length))
timestamp += [round(i / fps, 6) for i in range(length)]
task_index += [0] * length
subtask_index += [0] * length # legacy column the writer must drop
pd.DataFrame(
{
"episode_index": episode_index,
"frame_index": frame_index,
"timestamp": timestamp,
"task_index": task_index,
"subtask_index": subtask_index,
}
).to_parquet(data_dir / "file-000.parquet", index=False)
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
write_tasks(tasks_df, root)
write_json(
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
root / "meta" / "info.json",
)
return root
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
"""Rewriting a packed shard must keep one row group per episode, not collapse
every episode into a single giant row group."""
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
shard = root / "data" / "chunk-000" / "file-000.parquet"
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
staging_dir = tmp_path / "stage"
for ep in range(len(episode_lengths)):
_stage_episode(
staging_dir,
ep,
plan=[
{
"role": "assistant",
"content": f"subtask for ep {ep}",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
}
],
)
records = list(iter_episodes(root))
LanguageColumnsWriter().write_all(records, staging_dir, root)
# One row group per episode, with row counts matching the episode lengths.
md = pq.ParquetFile(shard).metadata
assert md.num_row_groups == len(episode_lengths)
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
# Language columns are still present after the per-episode rewrite.
table = pq.read_table(shard)
assert "language_persistent" in table.column_names
assert "language_events" in table.column_names
def test_speech_atom_shape_matches_plan_spec() -> None:
atom = speech_atom(2.5, "I'm cleaning up!")
assert atom["role"] == "assistant"
+55
View File
@@ -32,6 +32,26 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID
def assert_data_shards_one_row_group_per_episode(root):
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
import pyarrow.parquet as pq
shards = sorted((root / "data").rglob("*.parquet"))
assert shards, f"no data shards found under {root}/data"
n_episodes = 0
for shard in shards:
pf = pq.ParquetFile(shard)
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
assert pf.metadata.num_row_groups == len(set(episodes)), shard
for i in range(pf.metadata.num_row_groups):
rg_episodes = set(
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
)
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
n_episodes += len(set(episodes))
return n_episodes
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
"""Test that total number of episodes and frames are correctly aggregated."""
assert aggr_ds.num_episodes == expected_episodes, (
@@ -566,6 +586,41 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
)
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
Covers both the non-image (``df.to_parquet``) and image
(``to_parquet_with_hf_images``) write branches, including the merge-into-
existing-file branch via a low file-size threshold that forces packing.
"""
ds_0 = lerobot_dataset_factory(
root=tmp_path / "rg_0",
repo_id=f"{DUMMY_REPO_ID}_rg_0",
total_episodes=3,
total_frames=60,
use_videos=use_videos,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "rg_1",
repo_id=f"{DUMMY_REPO_ID}_rg_1",
total_episodes=4,
total_frames=80,
use_videos=use_videos,
)
aggr_root = tmp_path / "rg_aggr"
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
aggr_root=aggr_root,
)
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
+21 -3
View File
@@ -2370,14 +2370,32 @@ def test_aggregate_images_when_use_videos_false():
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=False, # expect "image" dtype
use_videos=False, # images kept, stored as "image" dtype
patterns=None,
)
key = f"{OBS_IMAGES}.back"
key_front = f"{OBS_IMAGES}.front"
assert key not in out
assert key_front not in out
assert key in out
assert key_front in out
assert out[key]["dtype"] == "image"
assert out[key_front]["dtype"] == "image"
assert out[key]["shape"] == initial["back"]
def test_aggregate_images_excluded():
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
initial = {"back": (480, 640, 3)}
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
exclude_images=True,
patterns=None,
)
assert f"{OBS_IMAGES}.back" not in out
assert f"{OBS_IMAGES}.front" not in out
def test_aggregate_images_when_use_videos_true():
+3 -3
View File
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
import pytest
from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
from lerobot.teleoperators.rebot_102_leader import (
RebotArm102Leader,
RebotArm102LeaderConfig,
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebotArm102LeaderConfig(
cfg = BiRebot102LeaderConfig(
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
)
teleop = BiRebotArm102Leader(cfg)
teleop = BiRebot102Leader(cfg)
assert any(k.startswith("left_") for k in teleop.action_features)
assert any(k.startswith("right_") for k in teleop.action_features)
assert "left_gripper.pos" in teleop.action_features
+2 -2
View File
@@ -134,7 +134,7 @@ class TestMultiGPUTraining:
f"--output_dir={output_dir}",
"--batch_size=4",
"--steps=10",
"--env_eval_freq=-1",
"--eval_freq=-1",
"--log_freq=5",
"--save_freq=10",
"--seed=42",
@@ -177,7 +177,7 @@ class TestMultiGPUTraining:
f"--output_dir={output_dir}",
"--batch_size=4",
"--steps=20",
"--env_eval_freq=-1",
"--eval_freq=-1",
"--log_freq=5",
"--save_freq=10",
"--seed=42",
Generated
+949 -900
View File
File diff suppressed because it is too large Load Diff