Compare commits

..

2 Commits

Author SHA1 Message Date
Khalil Meftah 6407a244c0 feat(envs): add generic observation passthrough
- Add generic observation passthrough in preprocess_observation() for
unhandled ndarray/tensor keys, replacing the pattern of adding per-env
hardcoded key handlers. Extra keys are forwarded as observation.<key>
and can be shaped by env-specific ProcessorSteps via get_env_processors().
2026-06-15 14:17:59 +02:00
Khalil Meftah 0511c12b8f feat(envs): add env plugin discovery
- Add 'lerobot_env_' to third-party plugin discovery prefixes, completing
the plugin system for all component types (robots, cameras, teleoperators,
policies, and now environments). External packages named lerobot_env_*
can self-register EnvConfig subclasses on import, enabling --env.type=
resolution without lerobot code changes.
2026-06-15 14:13:12 +02:00
48 changed files with 1384 additions and 2921 deletions
+8 -8
View File
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
**Compatible teleoperators:**
- `bi_openarm_mini` - Bimanual OpenArm Mini
- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \
+1 -1
View File
@@ -117,7 +117,7 @@ lerobot-rollout \
--strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \
--teleop.type=bi_openarm_mini \
--teleop.type=openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt"
```
-217
View File
@@ -1,217 +0,0 @@
#!/usr/bin/env python
"""
SONIC planner with full mode control.
Keyboard controls:
N / P - next / previous motion set
1-8 - select mode within current set
WASD - movement direction
Q / E - rotate facing left / right
9 / 0 - decrease / increase speed
- / = - decrease / increase height
R - force replan
Space - emergency stop -> IDLE
Esc - quit
Gamepad controls (Unitree wireless controller):
Left stick Y - speed (forward = fast, back = stop)
Left stick X - movement direction (offset from facing)
Right stick X - facing direction (incremental rotation)
Right stick Y - height (up = tall 0.8m, down = low 0.1m)
Buttons - unused (mode selection is keyboard-only)
For teleop integration use --robot.controller=SonicWholeBodyController instead.
"""
import argparse
import gc
import time
import numpy as np
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.controllers.sonic_whole_body import SonicRuntime
from lerobot.robots.unitree_g1.controllers.sonic_pipeline import (
CONTROL_DT,
DEFAULT_ANGLES,
LM,
MOTION_SETS,
RawKeyboard,
compute_kp_kd,
drain_keyboard,
)
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
def main():
parser = argparse.ArgumentParser(description="SONIC planner with keyboard + gamepad control")
parser.add_argument("--ip", type=str, default=None,
help="Robot IP for real hardware (e.g. 192.168.123.164). "
"Omit for simulation.")
parser.add_argument("--log-csv", action="store_true",
help="Write /tmp/sonic_pose_log.csv (disabled by default for teleop perf)")
parser.add_argument("--cpu", action="store_true",
help="Force CPU ONNX Runtime (skip CUDA even if onnxruntime-gpu is installed)")
parser.add_argument("--headless", action="store_true",
help="Ignored for sim (stock UnitreeG1 uses hub MuJoCo defaults)")
parser.add_argument("--gamepad", action="store_true",
help="Read Unitree wireless gamepad in sim (default: keyboard-only in sim)")
parser.add_argument("--keyboard-only", action="store_true",
help="Ignore wireless gamepad (terminal keyboard only)")
args = parser.parse_args()
print("=" * 60)
print("SONIC planner - full mode control")
print(" N/P cycle sets | 1-8 select mode | WASD move")
print(" Q/E rotate | 9/0 speed | -/= height")
print(" R replan | Space IDLE | Esc quit")
if args.ip:
print(f" Robot IP: {args.ip}")
else:
print(" Mode: simulation")
print("=" * 60 + "\n")
cfg = UnitreeG1Config(controller=None) # full-body SONIC; standalone loop owns publish
if args.ip:
cfg.is_simulation = False
cfg.robot_ip = args.ip
else:
cfg.is_simulation = True
if args.headless:
print("[Note] --headless ignored: sim uses stock UnitreeG1 + hub env")
robot = UnitreeG1(cfg)
robot.connect()
kp, kd = compute_kp_kd()
robot.kp = kp.copy()
robot.kd = kd.copy()
runtime = SonicRuntime(force_cpu=args.cpu)
controller = runtime.controller
ms = runtime.ms
runtime.controller.print_input_diagnostics()
print(f"\nStarting: {MOTION_SETS[0][0]} (default mode: {LM(ms.mode).name})")
[print(f" {i+1}: {m.name}") for i, m in enumerate(MOTION_SETS[0][1])]
print("\n[Ready] Click THIS terminal, then W/A/S/D to move. "
"1-6 change mode, 9/0 speed, Esc quit.\n", flush=True)
# Sim hub publishes wireless_remote bytes that can fight terminal WASD.
use_joystick = not args.keyboard_only and (args.gamepad or args.ip is not None)
with RawKeyboard() as kb:
try:
gc.disable()
gc_timer = 0.0
robot.reset(CONTROL_DT, DEFAULT_ANGLES)
time.sleep(1.0)
last_status = time.time() - 2.1
loop_t = enc_t = dec_t = obs_t = act_t = []
slow_n = blend_n = 0
stall_src = ""
did_blend = False
prev_end = time.time()
t_start = time.time()
log_path = "/tmp/sonic_pose_log.csv"
jnames = [m.name for m in G1_29_JointIndex]
log_ctx = open(log_path, "w") if args.log_csv else None
if log_ctx:
log_ctx.write("t,step,cursor,ts,blend,mode," +
",".join(f"q{i}" for i in range(29)) + "," +
",".join(f"ref{i}" for i in range(29)) + "," +
",".join(f"act{i}" for i in range(29)) +
",delta_max,action_norm,token_norm\n")
try:
while not robot._shutdown_event.is_set():
t0 = time.time()
if drain_keyboard(kb, ms, controller):
break
obs = robot.get_observation()
t_obs = time.time()
obs_t.append(1000 * (t_obs - t0))
if not obs:
runtime.tick({}, use_joystick=False)
time.sleep(max(0.0, CONTROL_DT - (time.time() - t0)))
continue
step_before = runtime.step
t_step = time.time()
action = runtime.tick(obs, use_joystick=use_joystick)
step_ms = 1000 * (time.time() - t_step)
do_enc = step_before % 5 == 0
(enc_t if do_enc else dec_t).append(step_ms)
t_act = time.time()
robot.send_action(action)
act_t.append(1000 * (time.time() - t_act))
if log_ctx and runtime.step % 5 == 0:
t_rel = time.time() - t_start
q_r = np.array([obs.get(f"{n}.q", 0) for n in jnames])
a_v = np.array([action.get(f"{n}.q", 0) for n in jnames])
cur, ts = controller.ref_cursor, controller.motion_timesteps
q_ref = controller.motion_joint_positions[min(cur, ts - 1)] if ts > 0 else np.zeros(29)
log_ctx.write(f"{t_rel:.4f},{runtime.step},{cur},{ts},{int(did_blend)},{ms.mode}," +
",".join(f"{v:.6f}" for v in q_r) + "," +
",".join(f"{v:.6f}" for v in q_ref) + "," +
",".join(f"{v:.6f}" for v in a_v) + "," +
f"{np.max(np.abs(a_v - q_r)):.6f},"
f"{np.linalg.norm(a_v):.6f},"
f"{np.linalg.norm(controller.token):.6f}\n")
did_blend = False
now = time.time()
loop_ms = 1000 * (now - t0)
if loop_ms > 50:
stall_src = (f"[STALL] {loop_ms:.0f}ms: "
f"obs={obs_t[-1]:.0f} step={step_ms:.0f} act={act_t[-1]:.0f}")
if loop_ms > CONTROL_DT * 1500:
slow_n += 1
if now - last_status > 2.0:
def _avg(lst):
return sum(lst) / len(lst) if lst else 0
hz = 1000 / _avg(loop_t) if _avg(loop_t) else 0
print(f"\r {ms.status_line()} step={runtime.step} "
f"ref={controller.ref_cursor}/{controller.motion_timesteps} "
f"loop={_avg(loop_t):.1f}ms(max={max(loop_t, default=0):.1f}) hz={hz:.0f} "
f"enc={_avg(enc_t):.1f} dec={_avg(dec_t):.1f} obs={_avg(obs_t):.1f} "
f"slow={slow_n} blends={blend_n}", end="", flush=True)
if stall_src:
print(f"\n {stall_src}")
stall_src = ""
last_status = now
loop_t = enc_t = dec_t = obs_t = act_t = []
slow_n = blend_n = 0
prev_end = time.time()
gc_timer += CONTROL_DT
if gc_timer >= 10.0:
gc.collect()
gc_timer = 0.0
loop_t.append(loop_ms)
time.sleep(max(0.0, CONTROL_DT - (time.time() - t0)))
finally:
if log_ctx:
log_ctx.close()
except KeyboardInterrupt:
pass
finally:
gc.enable()
if args.log_csv:
print(f"\n[Log] Saved to {log_path}")
runtime.shutdown()
print("\nStopping...")
if robot.is_connected:
robot.disconnect()
print("Done.")
if __name__ == "__main__":
main()
@@ -54,7 +54,6 @@ from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
@@ -275,11 +274,12 @@ class LanguageColumnsWriter:
new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
)
# Re-emit one row group per episode (a bulk pq.write_table would collapse
# them into one). Write to a sibling tmp path and atomically rename so a
# crash mid-write can't leave a half-written shard.
# Atomic replace: write to a sibling tmp path and rename so a crash
# mid-write can't leave a half-written shard that ``pq.read_table``
# would then fail to open. ``Path.replace`` is atomic on POSIX +
# Windows when source and target sit on the same filesystem.
tmp_path = path.with_suffix(path.suffix + ".tmp")
write_table_one_row_group_per_episode(new_table, tmp_path)
pq.write_table(new_table, tmp_path)
tmp_path.replace(path)
def _materialize_table(
-9
View File
@@ -32,7 +32,6 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
from .io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
to_parquet_one_row_group_per_episode,
to_parquet_with_hf_images,
write_info,
write_stats,
@@ -552,7 +551,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
aggr_root=dst_meta.root,
hf_features=hf_features,
concatenate=concatenate_data,
one_row_group_per_episode=True,
)
# Record the mapping from source to actual destination
@@ -630,7 +628,6 @@ def append_or_create_parquet_file(
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
concatenate: bool = True,
one_row_group_per_episode: bool = False,
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -648,8 +645,6 @@ def append_or_create_parquet_file(
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one.
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
the episodes-metadata parquet (already one row per episode).
Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -662,8 +657,6 @@ def append_or_create_parquet_file(
dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(df, dst_path)
else:
df.to_parquet(dst_path)
return idx, (dst_chunk, dst_file)
@@ -690,8 +683,6 @@ def append_or_create_parquet_file(
if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(final_df, target_path)
else:
final_df.to_parquet(target_path)
+9 -38
View File
@@ -20,7 +20,6 @@ import datasets
import numpy as np
import pandas
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
@@ -271,49 +270,21 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
return items_dict
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
"""Write ``table`` with one parquet row group per episode (in episode order).
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
mirroring the recording writer. ``table`` must carry a contiguous
``episode_index`` column.
"""
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
try:
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
finally:
writer.close()
def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None:
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
Images are embedded into the arrow table first (``ParquetWriter.write_table``
does not embed external image files like ``Dataset.to_parquet`` does).
``features`` types image columns as ``Image()`` in the parquet schema.
Args:
df: DataFrame to write to parquet.
path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
"""
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds = embed_images(ds)
table = ds.with_format("arrow")[:]
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
# No episode boundaries to align row groups to — keep a single write.
pq.write_table(table, str(path))
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
table = pa.Table.from_pandas(df, preserve_index=False)
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
pq.write_table(table, str(path))
ds.to_parquet(path)
def item_to_torch(item: dict) -> dict:
+3 -5
View File
@@ -70,21 +70,19 @@ def aggregate_pipeline_dataset_features(
initial_features: dict[PipelineFeatureType, dict[str, Any]],
*,
use_videos: bool = True,
exclude_images: bool = False,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `exclude_images` and `patterns`, and finally
(image or state), filters them based on `use_videos` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset.
Args:
pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
exclude_images: If True, image features are dropped entirely from the output.
use_videos: If False, image features are excluded.
patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter.
@@ -122,7 +120,7 @@ def aggregate_pipeline_dataset_features(
)
# 2. Apply filtering rules.
if is_image and exclude_images:
if is_image and not use_videos:
continue
if not is_image and not should_keep(key, compiled_patterns):
continue
+20
View File
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
if "camera_obs" in observations:
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
# Pass through any remaining ndarray/tensor keys not already handled above,
# so env plugins can expose extra observation keys via get_env_processors().
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
for key, value in observations.items():
if key in _handled:
continue
target = f"{OBS_STR}.{key}"
if target in return_observations:
continue
if isinstance(value, np.ndarray):
val = torch.from_numpy(value).float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
elif isinstance(value, Tensor):
val = value.float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
return return_observations
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
logger = logging.getLogger(__name__)
class BiOpenArmFollower(BimanualMixin, Robot):
class BiOpenArmFollower(Robot):
"""
Bimanual OpenArm Follower Arms
"""
@@ -40,17 +39,15 @@ class BiOpenArmFollower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
# will only open the cameras assigned to it. Per-arm cameras are used
# as fallback when top-level cameras are empty.
if config.cameras:
left_cameras = config.cameras
right_cameras = {}
else:
left_cameras = config.left_arm_config.cameras
right_cameras = config.right_arm_config.cameras
left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None,
@@ -59,7 +56,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras,
cameras=left_cameras,
side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd,
@@ -78,7 +75,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=config.right_arm_config.cameras,
cameras=right_cameras,
side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd,
@@ -98,19 +95,22 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@property
def _motors_ft(self) -> dict[str, type]:
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return {
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
}
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
# "right_wrist"), so we merge them directly — unlike motors which need the
# left_/right_ prefix to disambiguate identical per-arm joint names.
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -120,6 +120,27 @@ class BiOpenArmFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -127,15 +148,21 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Camera keys that should NOT get the arm prefix (they already have unique names)
left_cam_keys = set(self.left_arm.cameras.keys())
right_cam_keys = set(self.right_arm.cameras.keys())
# Add "right_" prefix
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
right_obs = self.right_arm.get_observation()
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
return obs_dict
@@ -162,4 +189,9 @@ class BiOpenArmFollower(BimanualMixin, Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -32,7 +32,5 @@ class BiOpenArmFollowerConfig(RobotConfig):
left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
# Top-level cameras shared across both arms.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
logger = logging.getLogger(__name__)
class BiRebotB601Follower(BimanualMixin, Robot):
class BiRebotB601Follower(Robot):
"""Bimanual Seeed Studio reBot B601-DM follower.
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
@@ -42,18 +41,6 @@ class BiRebotB601Follower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -62,7 +49,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
dm_serial_baud=config.left_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras,
cameras=config.left_arm_config.cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
@@ -99,12 +86,10 @@ class BiRebotB601Follower(BimanualMixin, Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
return {
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -114,13 +99,32 @@ class BiRebotB601Follower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
for k, v in self.left_arm.get_observation().items():
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm.get_observation().items():
obs_dict[f"right_{k}"] = v
obs_dict = {}
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
return obs_dict
@check_if_not_connected
@@ -139,3 +143,8 @@ class BiRebotB601Follower(BimanualMixin, Robot):
**{f"left_{k}": v for k, v in sent_action_left.items()},
**{f"right_{k}": v for k, v in sent_action_right.items()},
}
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from dataclasses import dataclass
from ..config import RobotConfig
from ..rebot_b601_follower import RebotB601FollowerConfig
@@ -29,8 +27,3 @@ class BiRebotB601FollowerConfig(RobotConfig):
left_arm_config: RebotB601FollowerConfig
right_arm_config: RebotB601FollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from ..so_follower import SOFollower, SOFollowerRobotConfig
@@ -28,7 +27,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
logger = logging.getLogger(__name__)
class BiSOFollower(BimanualMixin, Robot):
class BiSOFollower(Robot):
"""
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
@@ -40,18 +39,6 @@ class BiSOFollower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = SOFollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -59,7 +46,7 @@ class BiSOFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
use_degrees=config.left_arm_config.use_degrees,
cameras=left_arm_cameras,
cameras=config.left_arm_config.cameras,
)
right_arm_config = SOFollowerRobotConfig(
@@ -90,12 +77,13 @@ class BiSOFollower(BimanualMixin, Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
left_arm_cameras_ft = self.left_arm._cameras_ft
right_arm_cameras_ft = self.right_arm._cameras_ft
return {
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -105,21 +93,42 @@ class BiSOFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Add "left_" prefix
left_obs = self.left_arm.get_observation()
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
# Add "right_" prefix
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
right_obs = self.right_arm.get_observation()
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
return obs_dict
@@ -142,3 +151,8 @@ class BiSOFollower(BimanualMixin, Robot):
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from dataclasses import dataclass
from ..config import RobotConfig
from ..so_follower import SOFollowerConfig
@@ -29,8 +27,3 @@ class BiSOFollowerConfig(RobotConfig):
left_arm_config: SOFollowerConfig
right_arm_config: SOFollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -68,6 +68,6 @@ class UnitreeG1Config(RobotConfig):
# Compensates for gravity on the unitree's arms using the arm ik solver
gravity_compensation: bool = False
# Locomotion controller class name, e.g. "GrootLocomotionController",
# "HolosomaLocomotionController", or "SonicWholeBodyController". None disables it.
# Lower-body controller class name, e.g. "GrootLocomotionController" or
# "HolosomaLocomotionController". None disables it.
controller: str | None = None
@@ -1,8 +0,0 @@
"""Unitree G1 locomotion controllers (Groot, Holosoma, SONIC)."""
__all__ = [
"GrootLocomotionController",
"HolosomaLocomotionController",
"SonicWholeBodyController",
"SonicRuntime",
]
@@ -1,913 +0,0 @@
"""SONIC planner pipeline: ONNX enc/dec/planner, movement state, and input helpers."""
import math
import queue
import select
import struct
import sys
import termios
import threading
import time
import tty
from dataclasses import dataclass
from enum import IntEnum
import numpy as np
import onnxruntime as ort
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
# ── Constants ────────────────────────────────────────────────────────────────
DEFAULT_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
0.0, 0.0, 0.0,
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
], dtype=np.float32)
NATURAL_FREQ = 10.0 * 2.0 * np.pi
ARMATURE = {"5020": 0.003609725, "7520_14": 0.010177520, "7520_22": 0.025101925, "4010": 0.00425}
EFFORT = {"5020": 25.0, "7520_14": 88.0, "7520_22": 139.0, "4010": 5.0}
def _action_scale(k):
return 0.25 * EFFORT[k] / (ARMATURE[k] * NATURAL_FREQ**2)
_J = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
["7520_14","5020","5020"] + \
["5020","5020","5020","5020","5020","4010","4010"] * 2
ACTION_SCALE = np.array([_action_scale(k) for k in _J], dtype=np.float32)
CONTROL_DT = 0.02
DEFAULT_HEIGHT = 0.788740
TOKEN_DIM = 64
ENCODER_UPDATE_EVERY = 5
DEBUG_PRINT_EVERY = 100
MOTION_LOOK_AHEAD_STEPS = 2
INITIAL_RANDOM_SEED = 1234
MIN_TOKENS, MAX_TOKENS = 6, 16
K = MAX_TOKENS - MIN_TOKENS + 1
DEADZONE = 0.05
BLEND_FRAMES = 8
REPLAN_INTERVAL = {
"running": 0.1, "crawling": 0.2, "boxing": 1.0, "default": 1.0
}
ISAACLAB_TO_MUJOCO = np.array([
0, 3, 6, 9, 13, 17, 1, 4, 7, 10, 14, 18, 2, 5, 8,
11, 15, 19, 21, 23, 25, 27, 12, 16, 20, 22, 24, 26, 28
], dtype=np.int32)
MUJOCO_TO_ISAACLAB = np.array([
0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 22, 4, 10,
16, 23, 5, 11, 17, 24, 18, 25, 19, 26, 20, 27, 21, 28
], dtype=np.int32)
def _to_mujoco(a): return a[MUJOCO_TO_ISAACLAB]
def _to_runtime(a): r = np.zeros(29, np.float32); r[MUJOCO_TO_ISAACLAB] = a; return r
DEFAULT_ANGLES_MUJOCO = _to_mujoco(DEFAULT_ANGLES)
ENCODER_STANDING_REF = DEFAULT_ANGLES.copy()
LOWER_BODY_IL = np.array([0,3,6,9,13,17,1,4,7,10,14,18], dtype=np.int32)
WRIST_IL = np.array([23,24,25,26,27,28], dtype=np.int32)
VR_TARGET_DEF = np.zeros(9, dtype=np.float32)
VR_ORN_DEF = np.array([1,0,0,0,1,0,0,0,1,0,0,0], dtype=np.float32)
SMPL_DEF = np.zeros(720, dtype=np.float32)
# ── PD gains ─────────────────────────────────────────────────────────────────
def compute_kp_kd():
s = lambda k: ARMATURE[k] * NATURAL_FREQ**2
d = lambda k: 2.0 * 2.0 * ARMATURE[k] * NATURAL_FREQ
_kp_keys = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
["7520_14","5020","5020"] + \
["5020","5020","5020","5020","5020","4010","4010"] * 2
_kd_keys = _kp_keys
_double = {4,5,10,11,13,14} # ankle + waist indices with factor 2
kp = np.array([2*s(k) if i in _double else s(k) for i,k in enumerate(_kp_keys)], dtype=np.float32)
kd = np.array([2*d(k) if i in _double else d(k) for i,k in enumerate(_kd_keys)], dtype=np.float32)
return kp, kd
_kp_kd = compute_kp_kd # backward-compatible alias
def lowstate_to_obs(lowstate) -> dict:
"""Build a robot observation dict from Unitree lowstate."""
obs: dict = {}
for motor in G1_29_JointIndex:
idx = motor.value
obs[f"{motor.name}.q"] = float(lowstate.motor_state[idx].q)
obs[f"{motor.name}.dq"] = float(lowstate.motor_state[idx].dq)
quat = lowstate.imu_state.quaternion
obs["imu.quat.w"] = float(quat[0])
obs["imu.quat.x"] = float(quat[1])
obs["imu.quat.y"] = float(quat[2])
obs["imu.quat.z"] = float(quat[3])
gyro = lowstate.imu_state.gyroscope
obs["imu.gyro.x"] = float(gyro[0])
obs["imu.gyro.y"] = float(gyro[1])
obs["imu.gyro.z"] = float(gyro[2])
wr = getattr(lowstate, "wireless_remote", None)
if wr is not None:
obs["wireless_remote"] = bytes(wr) if not isinstance(wr, (bytes, bytearray)) else wr
return obs
# ── Quaternion helpers ────────────────────────────────────────────────────────
def quat_conj(q):
return np.array([q[0], -q[1], -q[2], -q[3]], dtype=np.float32)
def quat_mul(q1, q2):
w1,x1,y1,z1 = q1; w2,x2,y2,z2 = q2
return np.array([
w1*w2 - x1*x2 - y1*y2 - z1*z2,
w1*x2 + x1*w2 + y1*z2 - z1*y2,
w1*y2 - x1*z2 + y1*w2 + z1*x2,
w1*z2 + x1*y2 - y1*x2 + z1*w2,
], dtype=np.float32)
def gravity_dir(q):
q = q / (np.linalg.norm(q) + 1e-8)
qv = np.array([0, 0, 0, -1], dtype=np.float32)
return quat_mul(quat_mul(quat_conj(q), qv), q)[1:]
def quat_to_6d(q):
w,x,y,z = q
return np.array([
1-2*(y*y+z*z), 2*(x*y-z*w),
2*(x*y+z*w), 1-2*(x*x+z*z),
2*(x*z-y*w), 2*(y*z+x*w),
], dtype=np.float32)
def calc_heading(q):
w,x,y,z = q
return float(np.arctan2(2*(x*y + w*z), 1-2*(y*y+z*z)))
def heading_quat(q, sign=1.0):
a = sign * calc_heading(q) / 2.0
return np.array([np.cos(a), 0, 0, np.sin(a)], dtype=np.float64)
heading_quat_inv = lambda q: heading_quat(q, -1.0)
def quat_slerp(q0, q1, t):
q0 = q0 / (np.linalg.norm(q0)+1e-12); q1 = q1 / (np.linalg.norm(q1)+1e-12)
dot = float(np.dot(q0, q1))
if dot < 0: q1, dot = -q1, -dot
dot = min(dot, 1.0)
if dot > 0.9995:
r = q0 + t*(q1-q0); return r/(np.linalg.norm(r)+1e-12)
th = np.arccos(dot); st = np.sin(th)
return (np.sin((1-t)*th)/st)*q0 + (np.sin(t*th)/st)*q1
def quat_slerp_batch(q0, q1, t):
q0 = q0 / (np.linalg.norm(q0,axis=1,keepdims=True)+1e-12)
q1 = q1 / (np.linalg.norm(q1,axis=1,keepdims=True)+1e-12)
dot = np.sum(q0*q1, axis=1); neg = dot<0
q1=q1.copy(); q1[neg]=-q1[neg]; dot[neg]=-dot[neg]; dot=np.clip(dot,-1,1)
lin = dot>0.9995; th=np.arccos(dot); st=np.where(np.sin(th)==0,1,np.sin(th))
c0=np.sin((1-t)*th)/st; c1=np.sin(t*th)/st
c0[lin]=1-t[lin]; c1[lin]=t[lin]
r = c0[:,None]*q0 + c1[:,None]*q1
return r / (np.linalg.norm(r,axis=1,keepdims=True)+1e-12)
# ── Locomotion modes ──────────────────────────────────────────────────────────
class LocomotionMode(IntEnum):
IDLE=0; SLOW_WALK=1; WALK=2; RUN=3; SQUAT=4; KNEEL_TWO_LEGS=5; KNEEL=6
LYING_FACE_DOWN=7; CRAWLING=8; IDLE_BOXING=9; WALK_BOXING=10
LEFT_PUNCH=11; RIGHT_PUNCH=12; RANDOM_PUNCH=13; ELBOW_CRAWLING=14
LEFT_HOOK=15; RIGHT_HOOK=16; FORWARD_JUMP=17; STEALTH_WALK=18
INJURED_WALK=19; LEDGE_WALKING=20; OBJECT_CARRYING=21; STEALTH_WALK_2=22
HAPPY_DANCE_WALK=23; ZOMBIE_WALK=24; GUN_WALK=25; SCARE_WALK=26
LM = LocomotionMode
MOTION_SETS = [
("Standing", [LM.SLOW_WALK, LM.WALK, LM.RUN, LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK]),
("Squat / Low", [LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.CRAWLING, LM.ELBOW_CRAWLING]),
("Boxing", [LM.IDLE_BOXING, LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK]),
("Styled Walks", [LM.LEDGE_WALKING, LM.OBJECT_CARRYING, LM.STEALTH_WALK_2,
LM.HAPPY_DANCE_WALK, LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK]),
]
STATIC_MODES = {LM.IDLE, LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.LYING_FACE_DOWN, LM.IDLE_BOXING}
STANDING_MODES = {LM.IDLE, LM.SLOW_WALK, LM.WALK, LM.RUN, LM.IDLE_BOXING, LM.WALK_BOXING,
LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK,
LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK, LM.LEDGE_WALKING,
LM.OBJECT_CARRYING, LM.STEALTH_WALK_2, LM.HAPPY_DANCE_WALK,
LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK}
BOXING_MODES = {LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}
SPEED_RANGES = {LM.SLOW_WALK:(0.2,0.8), LM.WALK:(0.8,1.5), LM.RUN:(1.5,3.0),
LM.CRAWLING:(0.4,1.0), LM.ELBOW_CRAWLING:(0.7,1.0)}
def clamp_mode_params(ms):
m = LM(ms.mode)
ms.height = -1.0 if m in STANDING_MODES else max(0.1, min(0.8, ms.height if ms.height>=0 else 0.2))
if m in STATIC_MODES:
ms.speed = -1.0
elif m in SPEED_RANGES:
lo, hi = SPEED_RANGES[m]
ms.speed = max(lo, min(hi, ms.speed if ms.speed>=0 else lo))
elif m in BOXING_MODES:
ms.speed = max(0.7, min(1.5, ms.speed if ms.speed>=0 else 0.7))
else:
ms.speed = -1.0
def replan_interval(mode):
m = LM(mode)
if m == LM.RUN: return REPLAN_INTERVAL["running"]
if m == LM.CRAWLING: return REPLAN_INTERVAL["crawling"]
if m in {LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}:
return REPLAN_INTERVAL["boxing"]
return REPLAN_INTERVAL["default"]
def _ort_providers(force_cpu: bool = False) -> list[str]:
"""Prefer CUDA for enc/dec/planner (matches deploy when onnxruntime-gpu is installed)."""
avail = ort.get_available_providers()
if not force_cpu and "CUDAExecutionProvider" in avail:
return ["CUDAExecutionProvider", "CPUExecutionProvider"]
return ["CPUExecutionProvider"]
# ── Movement state ────────────────────────────────────────────────────────────
@dataclass
class MovementState:
mode: int = LM.SLOW_WALK # not IDLE — walking modes respond to WASD
speed: float = -1.0
height: float = -1.0
facing_angle: float = 0.0
movement_angle: float = 0.0
has_movement: bool = False
motion_set_idx: int = 0
needs_replan: bool = False
@property
def movement_direction(self):
if not self.has_movement: return (0.0, 0.0, 0.0)
return (math.cos(self.movement_angle), math.sin(self.movement_angle), 0.0)
@property
def facing_direction(self):
return (math.cos(self.facing_angle), math.sin(self.facing_angle), 0.0)
def status_line(self):
return (f"[{MOTION_SETS[self.motion_set_idx][0]}] mode={self.mode}({LM(self.mode).name}) "
f"spd={'default' if self.speed<0 else f'{self.speed:.1f}'} "
f"hgt={'default' if self.height<0 else f'{self.height:.2f}'} "
f"facing={math.degrees(self.facing_angle):.0f}° "
f"{'moving' if self.has_movement else 'still'}")
@dataclass
class MovementSnapshot:
mode: int = 0
speed: float = -1.0
height: float = -1.0
movement: tuple[float, float, float] = (0.0, 0.0, 0.0)
facing: tuple[float, float, float] = (1.0, 0.0, 0.0)
def _snapshot_ms(ms: MovementState) -> MovementSnapshot:
md, fd = ms.movement_direction, ms.facing_direction
return MovementSnapshot(ms.mode, ms.speed, ms.height, (md[0], md[1], md[2]), (fd[0], fd[1], fd[2]))
def should_replan_request(ms: MovementState, last: MovementSnapshot, replan_timer: float, step: int) -> bool:
"""Match C++ G1Deploy::Planner replan triggers (g1_deploy_onnx_ref.cpp)."""
if step <= 0:
return False
if ms.needs_replan:
return True
md, fd = ms.movement_direction, ms.facing_direction
facing_changed = fd != last.facing
height_changed = ms.height != last.height
mode_changed = ms.mode != last.mode
speed_changed = ms.speed != last.speed
dir_changed = md != last.movement
is_static = LM(ms.mode) in STATIC_MODES
if mode_changed or facing_changed or height_changed:
return True
time_to_replan = replan_timer >= replan_interval(ms.mode)
if not is_static and (speed_changed or dir_changed or (time_to_replan and ms.speed != 0)):
return True
return False
# ── Encoder / Decoder ─────────────────────────────────────────────────────────
class StandingEncoderDecoder:
def __init__(self, encoder, decoder):
self.encoder, self.decoder = encoder, decoder
self.encoder_input = encoder.get_inputs()[0].name
self.decoder_input = decoder.get_inputs()[0].name
enc_dim = int(encoder.get_inputs()[0].shape[1])
dec_dim = int(decoder.get_inputs()[0].shape[1])
if enc_dim != 1762 or dec_dim != 994:
raise RuntimeError(f"Unexpected dims encoder={enc_dim}, decoder={dec_dim}")
self.token = np.zeros(TOKEN_DIM, np.float32)
self.last_action_mj = np.zeros(29, np.float32)
self.h_q_mj = [np.zeros(29, np.float32)] * 10
self.h_dq_mj = [np.zeros(29, np.float32)] * 10
self.h_ang = [np.zeros(3, np.float32)] * 10
self.h_act_mj = [np.zeros(29, np.float32)] * 10
self.h_quat = [np.array([1,0,0,0], np.float32)] * 10
self.init_base_quat = np.array([1,0,0,0], np.float32)
self.init_ref_quat = np.array([1,0,0,0], np.float32)
self._heading_init = False
self.encode_mode = 0
self.vr_3point_local_target = VR_TARGET_DEF.copy()
self.vr_3point_local_orn_target = VR_ORN_DEF.copy()
self.smpl_joints_10frame_step1 = SMPL_DEF.copy()
self.set_zero_reference()
def update_history(self, q, dq, ang, quat):
quat = quat / (np.linalg.norm(quat)+1e-8)
q_mj = _to_mujoco(q); dq_mj = _to_mujoco(dq)
self.h_q_mj = [q_mj - DEFAULT_ANGLES_MUJOCO] + self.h_q_mj[:-1]
self.h_dq_mj = [dq_mj] + self.h_dq_mj[:-1]
self.h_ang = [ang.copy()] + self.h_ang[:-1]
self.h_act_mj = [self.last_action_mj.copy()] + self.h_act_mj[:-1]
self.h_quat = [quat.copy()] + self.h_quat[:-1]
if not self._heading_init:
self.init_base_quat = quat.copy(); self._heading_init = True
def _heading_quat(self, q):
h = calc_heading(q) / 2.0
return np.array([np.cos(h), 0, 0, np.sin(h)], np.float32)
def _heading_quat_inv(self, q):
h = calc_heading(q) / 2.0
return np.array([np.cos(-h), 0, 0, np.sin(-h)], np.float32)
def _anchor_6d(self, base_quat, ref_quat=None):
if ref_quat is None: ref_quat = self.init_ref_quat
delta = quat_mul(self._heading_quat(self.init_base_quat), self._heading_quat_inv(self.init_ref_quat))
new_ref = quat_mul(delta, ref_quat)
return quat_to_6d(quat_mul(quat_conj(base_quat), new_ref))
def set_zero_reference(self):
self.motion_joint_positions = [ENCODER_STANDING_REF.copy()]
self.motion_joint_velocities = [np.zeros(29, np.float32)]
self.motion_body_quats = [np.array([1,0,0,0], np.float32)]
self.motion_body_z = [DEFAULT_HEIGHT]
self.motion_timesteps = 1
self.freeze_ref_frame = 0
self.init_ref_quat = self.motion_body_quats[0].copy()
def build_encoder_obs(self):
obs = np.zeros(1762, np.float32)
obs[0] = float(self.encode_mode)
rf = min(self.freeze_ref_frame, self.motion_timesteps - 1)
ref_pos, ref_quat = self.motion_joint_positions[rf], self.motion_body_quats[rf]
if self.encode_mode == 0:
for f in range(10):
obs[4+29*f:4+29*(f+1)] = ref_pos
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
elif self.encode_mode == 1:
ref_lower = ref_pos[LOWER_BODY_IL]
for f in range(10):
obs[661+12*f:661+12*(f+1)] = ref_lower
obs[901:910] = self.vr_3point_local_target
obs[910:922] = self.vr_3point_local_orn_target
obs[595:601] = self._anchor_6d(self.h_quat[0], ref_quat)
elif self.encode_mode == 2:
obs[922:1642] = self.smpl_joints_10frame_step1
for f in range(10):
obs[1642+6*f:1642+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
obs[1702+6*f:1702+6*(f+1)] = ref_pos[WRIST_IL]
else:
raise RuntimeError(f"Unsupported encoder mode: {self.encode_mode}")
return obs
def build_decoder_obs(self):
obs = np.zeros(994, np.float32); off = 0
obs[off:off+64] = self.token; off += 64
for h, sz in [(list(reversed(self.h_ang)),3), (list(reversed(self.h_q_mj)),29),
(list(reversed(self.h_dq_mj)),29), (list(reversed(self.h_act_mj)),29)]:
for f in range(10): obs[off:off+sz] = h[f]; off += sz
for q in reversed(self.h_quat):
obs[off:off+3] = gravity_dir(q); off += 3
assert off == 994, f"Decoder obs mismatch: {off}"
return obs
def run_encoder(self):
return self.encoder.run(None, {self.encoder_input: self.build_encoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
def step(self, robot_obs, update_encoder, debug=False):
jnames = [m.name for m in G1_29_JointIndex]
q = np.array([robot_obs.get(f"{n}.q", DEFAULT_ANGLES[m.value]) for m,n in zip(G1_29_JointIndex,jnames)], np.float32)
dq = np.array([robot_obs.get(f"{n}.dq", 0.0) for n in jnames], np.float32)
quat = np.array([robot_obs.get("imu.quat.w",1), robot_obs.get("imu.quat.x",0),
robot_obs.get("imu.quat.y",0), robot_obs.get("imu.quat.z",0)], np.float32)
ang = np.array([robot_obs.get(f"imu.gyro.{a}",0) for a in "xyz"], np.float32)
self.update_history(q, dq, ang, quat)
if update_encoder: self.token = self.run_encoder()
action_mj = self.decoder.run(None, {self.decoder_input: self.build_decoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
self.last_action_mj = action_mj.copy()
target = DEFAULT_ANGLES + action_mj[ISAACLAB_TO_MUJOCO] * ACTION_SCALE
if debug:
delta = target - q
print(f"token_norm={np.linalg.norm(self.token):.4f} action_norm={np.linalg.norm(action_mj):.4f} "
f"delta_max={np.max(np.abs(delta)):.4f} delta_rms={np.sqrt(np.mean(delta**2)):.4f}")
return {f"{m.name}.q": float(target[m.value]) for m in G1_29_JointIndex}
def print_input_diagnostics(self):
print("\n[Diag] Standing reference checks")
names = {0:"g1", 1:"teleop", 2:"smpl"}
print(f" encoder mode: {self.encode_mode} ({names.get(self.encode_mode,'unknown')})")
print(f" DEFAULT_ANGLES range: [{DEFAULT_ANGLES.min():+.4f}, {DEFAULT_ANGLES.max():+.4f}]")
print(f" anchor_6d(identity): {self._anchor_6d(np.array([1,0,0,0],np.float32), np.array([1,0,0,0],np.float32))}")
print(f" gravity(identity): {gravity_dir(np.array([1,0,0,0],np.float32))} (expect [0,0,-1])")
dec0 = self.build_decoder_obs()
print(f" decoder q-delta max: {np.max(np.abs(dec0[94:384])):.6f}")
print(f" decoder dq max: {np.max(np.abs(dec0[384:674])):.6f}")
# ── Planner motion buffer ─────────────────────────────────────────────────────
class PlannerMotion:
def __init__(self, max_frames=1500):
self.timesteps = 0
self.joint_positions = np.zeros((max_frames, 29), np.float64)
self.joint_velocities = np.zeros((max_frames, 29), np.float64)
self.body_positions = np.zeros((max_frames, 3), np.float64)
self.body_quaternions = np.zeros((max_frames, 4), np.float64)
self.body_quaternions[:, 0] = 1.0
# ── Subprocess planner ────────────────────────────────────────────────────────
def _resample_30_to_50(qpos, n30):
t50 = int(np.floor(n30 / 30.0 * 50))
f30 = np.arange(t50) / 50.0 * 30.0
f0 = np.floor(f30).astype(int)
f1 = np.minimum(f0+1, n30-1)
frac, w0 = (f30-f0).astype(np.float64), None
w0 = 1.0 - frac
jp = (w0[:,None]*qpos[f0,7:36] + frac[:,None]*qpos[f1,7:36])[:,MUJOCO_TO_ISAACLAB]
jv = np.zeros_like(jp)
if t50 >= 2: jv[:t50-1] = (jp[1:] - jp[:-1]) * 50.0; jv[-1] = jv[-2]
return {
"timesteps": t50,
"joint_positions": jp,
"joint_velocities": jv,
"body_positions": w0[:,None]*qpos[f0,:3] + frac[:,None]*qpos[f1,:3],
"body_quaternions": quat_slerp_batch(qpos[f0,3:7], qpos[f1,3:7], frac),
}
def _build_planner_inputs(ctx, ms_dict, version, seed):
inp = {
"context_mujoco_qpos": ctx.astype(np.float32).reshape(1,4,36),
"target_vel": np.array([ms_dict["speed"]], np.float32),
"mode": np.array([ms_dict["mode"]], np.int64),
"movement_direction": np.array(ms_dict["movement_direction"], np.float32).reshape(1,3),
"facing_direction": np.array(ms_dict["facing_direction"], np.float32).reshape(1,3),
"random_seed": np.array([seed], np.int64),
}
if version >= 1:
# TensorRT deploy: allow 911 prediction tokens only (indices 35 for MIN_TOKENS=6).
allowed = np.zeros((1, K), np.int64)
if K >= 6:
allowed[0, 3:6] = 1
inp.update({
"height": np.array([ms_dict["height"]], np.float32),
"has_specific_target": np.array([[0]], np.int64),
"specific_target_positions": np.zeros((1,4,3), np.float32),
"specific_target_headings": np.zeros((1,4), np.float32),
"allowed_pred_num_tokens": allowed,
})
return inp
def _planner_worker(path, req_q, res_q, stop_evt, version, seed, use_gpu):
so = ort.SessionOptions(); so.log_severity_level = 3
providers = _ort_providers(force_cpu=not use_gpu)
sess = ort.InferenceSession(path, sess_options=so, providers=providers)
while not stop_evt.is_set():
try: ctx, gf, ms_dict = req_q.get(timeout=0.05)
except Exception: continue
try:
inp = _build_planner_inputs(ctx, ms_dict, version, seed)
t0 = time.time()
qpos_out, num_pred = sess.run(None, inp)
t_inf = time.time()
n = int(num_pred.flat[0])
qpos = qpos_out[0,:n]
if np.any(np.isnan(qpos)): continue
motion = _resample_30_to_50(qpos, n)
motion["gen_frame"] = gf
print(f"[Planner] inf={1000*(t_inf-t0):.1f}ms total={1000*(time.time()-t0):.1f}ms frames={n}", flush=True)
while not res_q.empty():
try: res_q.get_nowait()
except queue.Empty: break
res_q.put(motion)
except Exception as e:
print(f"[Planner] Error: {e}", flush=True)
# ── SonicPlanner ──────────────────────────────────────────────────────────────
class SonicPlanner:
def __init__(self, session, planner_path):
self.session = session
self.planner_path = planner_path
self.gen_frame = 0
self.random_seed = INITIAL_RANDOM_SEED
self.version = 1 if len(session.get_inputs()) >= 11 else 0
self.motion_50hz = PlannerMotion()
self._snapshot = PlannerMotion()
self._req_q = self._res_q = self._stop_evt = self._planner_thread = None
self._ctrl = None
def _build_inputs(self, ctx, ms):
return _build_planner_inputs(
ctx,
{"mode": ms.mode, "speed": ms.speed, "height": ms.height,
"movement_direction": list(ms.movement_direction),
"facing_direction": list(ms.facing_direction)},
self.version, self.random_seed)
@staticmethod
def build_initial_context(joint_positions):
ctx = np.zeros((4, 36), np.float32)
jp_mj = joint_positions.astype(np.float32)[ISAACLAB_TO_MUJOCO]
for n in range(4):
ctx[n, 2] = DEFAULT_HEIGHT
ctx[n, 3] = 1.0
ctx[n, 7:36] = jp_mj
return ctx
def _context_from_controller(self, current_frame):
ctrl = self._ctrl
gen_frame = current_frame + MOTION_LOOK_AHEAD_STEPS
t_arr = gen_frame / 50.0 + np.arange(4) / 30.0
f50 = t_arr * 50.0
with ctrl.motion_lock:
ts = ctrl.motion_timesteps
if ts <= 0:
return self.build_initial_context(DEFAULT_ANGLES)
bp, bq, jp = ctrl.motion_body_pos, ctrl.motion_body_quats, ctrl.motion_joint_positions
f0 = np.minimum(np.floor(f50).astype(int), ts - 1)
f1 = np.minimum(f0 + 1, ts - 1)
frac = f50 - f0
w0 = 1.0 - frac
ctx = np.zeros((4, 36), np.float32)
ctx[:, 0:3] = w0[:, None] * bp[f0] + frac[:, None] * bp[f1]
ctx[:, 3:7] = quat_slerp_batch(bq[f0], bq[f1], frac)
ij = w0[:, None] * jp[f0] + frac[:, None] * jp[f1]
ctx[:, 7:36] = ij[:, ISAACLAB_TO_MUJOCO]
self.gen_frame = gen_frame
return ctx
def _load_motion_in_place(self, qpos, n30, target=None):
if target is None: target = self.motion_50hz
r = _resample_30_to_50(qpos, n30)
n = r["timesteps"]; target.timesteps = n
target.joint_positions[:n] = r["joint_positions"]
target.joint_velocities[:n] = r["joint_velocities"]
target.body_positions[:n] = r["body_positions"]
target.body_quaternions[:n] = r["body_quaternions"]
return target
def initialize(self, joint_positions, ms):
ctx = self.build_initial_context(joint_positions)
qpos_out, num_pred = self.session.run(None, self._build_inputs(ctx, ms))
n = int(num_pred.flat[0]); qpos = qpos_out[0,:n]
if np.any(np.isnan(qpos)): raise RuntimeError("Planner initial output contains NaN")
print(f"[Planner] Init: {n} frames @ 30 Hz")
self._load_motion_in_place(qpos, n)
print(f"[Planner] Resampled to {self.motion_50hz.timesteps} frames @ 50 Hz")
return self.motion_50hz
def request_replan(self, cursor, ms):
if self._req_q is None: return
ctx = self._context_from_controller(cursor)
ms_dict = {"mode": ms.mode, "speed": ms.speed, "height": ms.height,
"movement_direction": list(ms.movement_direction),
"facing_direction": list(ms.facing_direction)}
while not self._req_q.empty():
try: self._req_q.get_nowait()
except queue.Empty: break
self._req_q.put((ctx, self.gen_frame, ms_dict))
def try_get_new_motion(self):
if self._res_q is None: return None
result = None
while not self._res_q.empty():
try: result = self._res_q.get_nowait()
except queue.Empty: break
if result is None: return None
n, gf = result["timesteps"], result["gen_frame"]
s = self._snapshot; s.timesteps = n
s.joint_positions[:n] = result["joint_positions"]
s.joint_velocities[:n] = result["joint_velocities"]
s.body_positions[:n] = result["body_positions"]
s.body_quaternions[:n] = result["body_quaternions"]
return s, gf
def start_subprocess(self, controller, use_gpu: bool = False):
"""Run planner ONNX in a background thread (avoids mp spawn/fork + CUDA/MuJoCo issues)."""
self._ctrl = controller
self._req_q = queue.Queue()
self._res_q = queue.Queue()
self._stop_evt = threading.Event()
self._planner_thread = threading.Thread(
target=_planner_worker,
args=(self.planner_path, self._req_q, self._res_q,
self._stop_evt, self.version, self.random_seed, use_gpu),
daemon=True,
name="sonic-planner",
)
self._planner_thread.start()
print(f"[Planner] Background thread started ({'GPU' if use_gpu else 'CPU'})")
def stop_subprocess(self):
if self._stop_evt:
self._stop_evt.set()
if self._planner_thread is not None:
self._planner_thread.join(timeout=3.0)
print("[Planner] Background thread stopped")
self._planner_thread = None
self._req_q = self._res_q = self._stop_evt = None
# ── PlannerController ─────────────────────────────────────────────────────────
class PlannerController(StandingEncoderDecoder):
def __init__(self, planner, encoder, decoder):
super().__init__(encoder, decoder)
self.planner = planner
self.ref_cursor = 0
self.motion_timesteps = 0
self.motion_joint_positions = np.zeros((1500,29), np.float64)
self.motion_joint_velocities = np.zeros((1500,29), np.float64)
self.motion_body_quats = np.zeros((1500,4), np.float64); self.motion_body_quats[:,0] = 1.0
self.motion_body_pos = np.zeros((1500,3), np.float64)
self.init_ref_quat = np.array([1,0,0,0], np.float64)
self.heading_init_base_quat = np.array([1,0,0,0], np.float64)
self.delta_heading = 0.0
self.reinit_heading = False
self.playing = self.first_motion = False
self.motion_lock = threading.Lock()
def load_initial_motion(self, motion):
with self.motion_lock:
n = motion.timesteps
self.motion_timesteps = n
self.motion_joint_positions[:n] = motion.joint_positions[:n]
self.motion_joint_velocities[:n] = motion.joint_velocities[:n]
self.motion_body_quats[:n] = motion.body_quaternions[:n]
self.motion_body_pos[:n] = motion.body_positions[:n]
self.init_ref_quat = motion.body_quaternions[0].copy()
self.ref_cursor = 0; self.first_motion = True
self.playing = True; self.delta_heading = 0.0
def blend_new_motion(self, new_motion, gen_frame):
"""Blend like C++ CurrentFrameAdvancement: 8-frame cross-fade, then copy tail."""
with self.motion_lock:
cur = self.ref_cursor
new_len = gen_frame - cur + new_motion.timesteps
if new_len <= 0:
return
if self.motion_timesteps == 0:
n = new_motion.timesteps
self.motion_joint_positions[:n] = new_motion.joint_positions[:n]
self.motion_joint_velocities[:n] = new_motion.joint_velocities[:n]
self.motion_body_pos[:n] = new_motion.body_positions[:n]
self.motion_body_quats[:n] = new_motion.body_quaternions[:n]
self.motion_timesteps = n
self.ref_cursor = 0
self.init_ref_quat = self.motion_body_quats[0].copy()
self.first_motion = False
return
blend_start = max(0, gen_frame - cur)
blend_end = min(new_len, blend_start + BLEND_FRAMES)
for f in range(blend_end):
f_old = min(f + cur, self.motion_timesteps - 1)
f_new = max(0, min(f + cur - gen_frame, new_motion.timesteps - 1))
w_new = min(1.0, max(0.0, (f - blend_start) / BLEND_FRAMES))
w_old = 1.0 - w_new
self.motion_joint_positions[f] = (
w_old * self.motion_joint_positions[f_old]
+ w_new * new_motion.joint_positions[f_new]
)
self.motion_joint_velocities[f] = (
w_old * self.motion_joint_velocities[f_old]
+ w_new * new_motion.joint_velocities[f_new]
)
self.motion_body_pos[f] = (
w_old * self.motion_body_pos[f_old]
+ w_new * new_motion.body_positions[f_new]
)
self.motion_body_quats[f] = quat_slerp(
self.motion_body_quats[f_old], new_motion.body_quaternions[f_new], w_new
)
for f in range(blend_end, new_len):
f_new = max(0, min(f + cur - gen_frame, new_motion.timesteps - 1))
self.motion_joint_positions[f] = new_motion.joint_positions[f_new]
self.motion_joint_velocities[f] = new_motion.joint_velocities[f_new]
self.motion_body_pos[f] = new_motion.body_positions[f_new]
self.motion_body_quats[f] = new_motion.body_quaternions[f_new].copy()
self.motion_timesteps = new_len
self.first_motion = False
self.ref_cursor = 0
self.init_ref_quat = self.motion_body_quats[0].copy()
def _heading_apply_delta(self):
delta = quat_mul(heading_quat(self.heading_init_base_quat).astype(np.float32),
heading_quat_inv(self.init_ref_quat).astype(np.float32))
if self.delta_heading:
h = self.delta_heading / 2.0
delta = quat_mul(np.array([np.cos(h),0,0,np.sin(h)], np.float32), delta)
return delta
def _anchor_6d(self, base_quat, ref_quat=None):
if ref_quat is None: ref_quat = self.init_ref_quat
new_ref = quat_mul(self._heading_apply_delta(), ref_quat.astype(np.float32))
return quat_to_6d(quat_mul(quat_conj(base_quat.astype(np.float32)), new_ref))
def build_encoder_obs(self):
obs = np.zeros(1762, np.float32); obs[0] = float(self.encode_mode)
with self.motion_lock:
for f in range(10):
tf = min(self.ref_cursor + f*5 if self.playing else self.ref_cursor,
self.motion_timesteps - 1)
obs[4+29*f:4+29*(f+1)] = self.motion_joint_positions[tf].astype(np.float32)
if self.playing:
obs[294+29*f:294+29*(f+1)] = self.motion_joint_velocities[tf].astype(np.float32)
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(
self.h_quat[0], self.motion_body_quats[tf].astype(np.float32))
return obs
def step(self, robot_obs, update_encoder, debug=False):
if robot_obs and (self.first_motion or self.reinit_heading):
q = None
if "imu.quat.w" in robot_obs:
q = np.array([
robot_obs["imu.quat.w"], robot_obs["imu.quat.x"],
robot_obs["imu.quat.y"], robot_obs["imu.quat.z"],
], np.float64)
else:
q = robot_obs.get("imu.quaternion")
if q is not None:
q = np.array(q, np.float64)
if q is not None:
self.heading_init_base_quat = np.array(q, np.float64)
with self.motion_lock:
rf = min(self.ref_cursor, self.motion_timesteps - 1)
self.init_ref_quat = self.motion_body_quats[rf].copy()
self.delta_heading = 0.0
self.first_motion = False
self.reinit_heading = False
print(f"[Heading] init quat: {self.heading_init_base_quat}")
return super().step(robot_obs, update_encoder=update_encoder, debug=debug)
def advance_cursor(self):
"""Advance one frame per 50 Hz tick (C++ current_frame_ += 1), no wall-clock catch-up."""
if not self.playing:
return
with self.motion_lock:
if self.motion_timesteps > 0:
self.ref_cursor = min(self.ref_cursor + 1, self.motion_timesteps - 1)
# ── Keyboard ──────────────────────────────────────────────────────────────────
class RawKeyboard:
def __init__(self):
self.fd = sys.stdin.fileno()
self.old = termios.tcgetattr(self.fd)
def __enter__(self): tty.setcbreak(self.fd); return self
def __exit__(self, *_): termios.tcsetattr(self.fd, termios.TCSADRAIN, self.old)
def get_key(self):
return sys.stdin.read(1) if select.select([sys.stdin],[],[],0)[0] else None
def drain_keyboard(kb, ms, controller=None) -> bool:
"""Process all pending terminal keys this frame (return True to quit)."""
quit_requested = False
while True:
key = kb.get_key()
if key is None:
break
if process_keyboard(key, ms, controller):
quit_requested = True
return quit_requested
def process_keyboard(key, ms, controller=None):
if key is None: return False
if key == '\x1b': return True
if key == ' ':
ms.mode = LM.IDLE; ms.speed = ms.height = -1.0
ms.has_movement = False; ms.needs_replan = True
if controller: controller.playing = False; controller.reinit_heading = True
print("\n >> EMERGENCY STOP -> IDLE"); return False
if key in ('r','R'):
ms.needs_replan = True; print("\n >> Manual replan"); return False
if key in ('n','N','p','P'):
ms.motion_set_idx = (ms.motion_set_idx + (1 if key in ('n','N') else -1)) % len(MOTION_SETS)
name, modes = MOTION_SETS[ms.motion_set_idx]
print(f"\n >> Motion set: {name}")
[print(f" {i+1}: {m.name}") for i,m in enumerate(modes)]
return False
if key.isdigit() and key not in ('9','0'):
idx = int(key) - 1; modes = MOTION_SETS[ms.motion_set_idx][1]
if 0 <= idx < len(modes):
ms.mode = modes[idx]; ms.needs_replan = True
if controller: controller.playing = True; controller.reinit_heading = True
print(f"\n >> Mode: {LM(ms.mode).name} ({ms.mode}) [replanning...]")
return False
if key == '9':
ms.speed = max(0.0, (ms.speed if ms.speed>=0 else 1.0) - 0.1)
print(f"\n >> Speed: {ms.speed:.1f}"); return False
if key == '0':
ms.speed = min(5.0, (ms.speed if ms.speed>=0 else 1.0) + 0.1)
print(f"\n >> Speed: {ms.speed:.1f}"); return False
if key == '-':
ms.height = max(0.2, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) - 0.02)
print(f"\n >> Height: {ms.height:.2f}"); return False
if key == '=':
ms.height = min(1.0, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) + 0.02)
print(f"\n >> Height: {ms.height:.2f}"); return False
if key.lower() == 'w': ms.movement_angle = ms.facing_angle
elif key.lower() == 's': ms.movement_angle = ms.facing_angle + math.pi
elif key.lower() == 'a': ms.movement_angle = ms.facing_angle + math.pi/2
elif key.lower() == 'd': ms.movement_angle = ms.facing_angle - math.pi/2
if key.lower() in ('w','s','a','d'):
ms.has_movement = ms.needs_replan = True
if controller:
controller.playing = True
print(f"\n >> Move {key.upper()} (replanning...)")
elif key.lower() == 'q':
ms.facing_angle += 0.1
if controller: controller.delta_heading += 0.1
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
elif key.lower() == 'e':
ms.facing_angle -= 0.1
if controller: controller.delta_heading -= 0.1
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
return False
_joy_prev_active = False
def _parse_wireless(wr):
"""Parse wireless_remote (bytes or int-array) into (lx, ly, rx, ry)."""
import struct as _st
if not isinstance(wr, (bytes, bytearray)):
wr = bytes(wr)
if len(wr) < 24:
return None
lx = _st.unpack("f", wr[4:8])[0]
rx = _st.unpack("f", wr[8:12])[0]
ry = _st.unpack("f", wr[12:16])[0]
ly = _st.unpack("f", wr[20:24])[0]
return lx, ly, rx, ry
def process_joystick(obs, ms, controller=None):
"""Joystick mirrors keyboard: left stick=WASD, right stick X=Q/E, right stick Y=height."""
global _joy_prev_active
wr = obs.get("wireless_remote")
if wr is None:
return
parsed = _parse_wireless(wr)
if parsed is None:
return
lx, ly, rx, ry = parsed
# Dead zone + negate both Y axes (bridge already flips them once)
lx = 0.0 if abs(lx) < DEADZONE else lx
ly = 0.0 if abs(ly) < DEADZONE else -ly
rx = 0.0 if abs(rx) < DEADZONE else rx
ry = 0.0 if abs(ry) < DEADZONE else -ry
left_active = abs(lx) > 0 or abs(ly) > 0
# Left stick → WASD (movement direction relative to facing)
if left_active:
ms.movement_angle = ms.facing_angle + math.atan2(-lx, -ly)
ms.has_movement = True
if not _joy_prev_active:
ms.needs_replan = True
_joy_prev_active = True
elif _joy_prev_active and not (abs(rx) > 0 or abs(ry) > 0):
_joy_prev_active = False
ms.has_movement = False
# Right stick X → Q/E (facing rotation, ~1 rad/s at full deflection)
if abs(rx) > 0:
delta = -0.02 * rx
ms.facing_angle += delta
if controller:
controller.delta_heading += delta
# Right stick Y → -/= (height adjustment, ~0.25/s at full deflection)
if abs(ry) > 0:
step = -0.005 * ry
ms.height = max(0.1, min(1.0, (ms.height if ms.height >= 0 else DEFAULT_HEIGHT) + step))
@@ -1,152 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SONIC full-body controller for Unitree G1."""
import logging
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.controllers.sonic_pipeline import (
CONTROL_DT,
DEBUG_PRINT_EVERY,
DEFAULT_ANGLES,
ENCODER_UPDATE_EVERY,
LM,
MOTION_SETS,
MovementState,
PlannerController,
SonicPlanner,
clamp_mode_params,
compute_kp_kd,
lowstate_to_obs,
process_joystick,
should_replan_request,
_ort_providers,
_snapshot_ms,
)
logger = logging.getLogger(__name__)
class SonicRuntime:
"""Shared SONIC control loop state (standalone demo + locomotion controller)."""
def __init__(self, force_cpu: bool = False):
planner_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="planner_sonic.onnx")
encoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_encoder.onnx")
decoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_decoder.onnx")
providers = _ort_providers(force_cpu=force_cpu)
self.use_gpu = providers[0] == "CUDAExecutionProvider"
so = ort.SessionOptions()
so.log_severity_level = 3
planner_sess = ort.InferenceSession(planner_path, sess_options=so, providers=providers)
encoder_sess = ort.InferenceSession(encoder_path, sess_options=so, providers=providers)
decoder_sess = ort.InferenceSession(decoder_path, sess_options=so, providers=providers)
self.kp, self.kd = compute_kp_kd()
self.ms = MovementState()
self.planner = SonicPlanner(planner_sess, planner_path)
self.controller = PlannerController(self.planner, encoder_sess, decoder_sess)
motion = self.planner.initialize(DEFAULT_ANGLES, self.ms)
self.controller.load_initial_motion(motion)
self.planner.start_subprocess(self.controller, use_gpu=self.use_gpu)
self.step = 0
self.replan_timer = 0.0
self.last_ms = _snapshot_ms(self.ms)
@property
def pipeline(self):
return self.controller
def tick(self, obs: dict, *, debug: bool | None = None, use_joystick: bool = True) -> dict:
if not obs:
self.step += 1
return {}
if use_joystick:
process_joystick(obs, self.ms, self.controller)
clamp_mode_params(self.ms)
if self.step > 0:
self.replan_timer += CONTROL_DT
if should_replan_request(self.ms, self.last_ms, self.replan_timer, self.step):
self.planner.request_replan(self.controller.ref_cursor, self.ms)
self.replan_timer = 0.0
self.ms.needs_replan = False
self.last_ms = _snapshot_ms(self.ms)
do_enc = self.step % ENCODER_UPDATE_EVERY == 0
if debug is None:
debug = self.step % DEBUG_PRINT_EVERY == 0
action = self.controller.step(obs, update_encoder=do_enc, debug=debug)
result = self.planner.try_get_new_motion()
if result:
self.controller.blend_new_motion(*result)
self.controller.advance_cursor()
self.step += 1
return action
def reset(self):
self.ms = MovementState()
self.controller.reinit_heading = True
self.controller.playing = True
self.step = 0
self.replan_timer = 0.0
self.last_ms = _snapshot_ms(self.ms)
def shutdown(self):
self.planner.stop_subprocess()
class SonicWholeBodyController:
"""Full-body SONIC controller for UnitreeG1's background controller thread."""
control_dt = CONTROL_DT
full_body = True
def __init__(self, force_cpu: bool = False):
logger.info("Loading SONIC whole-body controller...")
self._runtime = SonicRuntime(force_cpu=force_cpu)
self.kp = self._runtime.kp
self.kd = self._runtime.kd
self.controller = self._runtime.controller
self.ms = self._runtime.ms
logger.info(
"SONIC ready: %s (default mode: %s)",
MOTION_SETS[0][0],
LM(self.ms.mode).name,
)
def run_step(self, action: dict, lowstate) -> dict:
if lowstate is None:
return {}
obs = lowstate_to_obs(lowstate)
return self._runtime.tick(obs, debug=False)
def reset(self):
self._runtime.reset()
def shutdown(self):
self._runtime.shutdown()
+2 -3
View File
@@ -68,9 +68,8 @@ def make_locomotion_controller(name: str | None):
if name is None:
return None
controllers = {
"GrootLocomotionController": "lerobot.robots.unitree_g1.controllers.gr00t_locomotion",
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.controllers.holosoma_locomotion",
"SonicWholeBodyController": "lerobot.robots.unitree_g1.controllers.sonic_whole_body",
"GrootLocomotionController": "lerobot.robots.unitree_g1.gr00t_locomotion",
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.holosoma_locomotion",
}
module_path = controllers.get(name)
if module_path is None:
@@ -21,7 +21,7 @@ import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.g1_utils import (
from .g1_utils import (
REMOTE_AXES,
REMOTE_BUTTONS,
G1_29_JointIndex,
@@ -22,7 +22,7 @@ import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.g1_utils import (
from .g1_utils import (
REMOTE_AXES,
G1_29_JointArmIndex,
G1_29_JointIndex,
+1 -9
View File
@@ -338,9 +338,6 @@ class UnitreeG1(Robot):
self.kp = np.array(self.config.kp, dtype=np.float32)
self.kd = np.array(self.config.kd, dtype=np.float32)
if self.controller is not None and hasattr(self.controller, "kp"):
self.kp = np.array(self.controller.kp, dtype=np.float32)
self.kd = np.array(self.controller.kd, dtype=np.float32)
for joint in G1_29_JointIndex:
self.msg.motor_cmd[joint].mode = 1
@@ -377,9 +374,6 @@ class UnitreeG1(Robot):
# Signal thread to stop and unblock any waits
self._shutdown_event.set()
if self.controller is not None and hasattr(self.controller, "shutdown"):
self.controller.shutdown()
# Wait for subscribe thread to finish
if self.subscribe_thread is not None:
self.subscribe_thread.join(timeout=2.0)
@@ -471,11 +465,9 @@ class UnitreeG1(Robot):
def send_action(self, action: RobotAction) -> RobotAction:
action_to_publish = action
if self.controller is not None:
self._update_controller_action(action)
if getattr(self.controller, "full_body", False):
return action
# Controller thread owns legs/waist. Here we only update joystick inputs
# and publish arm targets from the teleoperator.
self._update_controller_action(action)
arm_prefixes = tuple(j.name for j in G1_29_JointArmIndex)
action_to_publish = {
key: value
-1
View File
@@ -54,7 +54,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
@@ -57,7 +57,6 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
-1
View File
@@ -137,7 +137,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
-1
View File
@@ -174,7 +174,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
@@ -41,7 +41,6 @@ from lerobot.robots import ( # noqa: F401
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
koch_leader,
@@ -89,7 +89,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
from ..teleoperator import Teleoperator
@@ -28,7 +27,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
logger = logging.getLogger(__name__)
class BiOpenArmLeader(BimanualMixin, Teleoperator):
class BiOpenArmLeader(Teleoperator):
"""
Bimanual OpenArm Leader Arms
"""
@@ -87,6 +86,27 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -109,3 +129,8 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
@dataclass
class BiOpenArmLeaderConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Leader teleoperators."""
"""Configuration class for Bi OpenArm Follower robots."""
left_arm_config: OpenArmLeaderConfigBase
right_arm_config: OpenArmLeaderConfigBase
@@ -1,20 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_openarm_mini import BiOpenArmMini
from .config_bi_openarm_mini import BiOpenArmMiniConfig
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
@@ -1,101 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
from ..teleoperator import Teleoperator
from .config_bi_openarm_mini import BiOpenArmMiniConfig
logger = logging.getLogger(__name__)
class BiOpenArmMini(BimanualMixin, Teleoperator):
"""Bimanual OpenArm Mini teleoperator.
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
bimanual leader can teleoperate a bimanual OpenArm follower.
"""
config_class = BiOpenArmMiniConfig
name = "bi_openarm_mini"
def __init__(self, config: BiOpenArmMiniConfig):
super().__init__(config)
self.config = config
# `side` is forced to match left/right regardless of what the user passed
# on the per-arm base config — the bimanual wrapper owns the side semantics.
left_arm_config = OpenArmMiniConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
side="left",
use_degrees=config.left_arm_config.use_degrees,
)
right_arm_config = OpenArmMiniConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
side="right",
use_degrees=config.right_arm_config.use_degrees,
)
self.left_arm = OpenArmMini(left_arm_config)
self.right_arm = OpenArmMini(right_arm_config)
@cached_property
def action_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
}
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
action: RobotAction = {}
for k, v in self.left_arm.get_action().items():
action[f"left_{k}"] = v
for k, v in self.right_arm.get_action().items():
action[f"right_{k}"] = v
return action
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
if left_fb:
self.left_arm.send_feedback(left_fb)
if right_fb:
self.right_arm.send_feedback(right_fb)
@@ -1,29 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
from ..openarm_mini import OpenArmMiniConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
@dataclass
class BiOpenArmMiniConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Mini teleoperators."""
left_arm_config: OpenArmMiniConfigBase
right_arm_config: OpenArmMiniConfigBase
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_rebot_102_leader import BiRebot102Leader
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
from .bi_rebot_102_leader import BiRebotArm102Leader
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
__all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]
@@ -18,17 +18,16 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
from ..teleoperator import Teleoperator
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
logger = logging.getLogger(__name__)
class BiRebot102Leader(BimanualMixin, Teleoperator):
class BiRebotArm102Leader(Teleoperator):
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
@@ -36,10 +35,10 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
leader can teleoperate a bimanual reBot B601 follower.
"""
config_class = BiRebot102LeaderConfig
config_class = BiRebotArm102LeaderConfig
name = "bi_rebot_102_leader"
def __init__(self, config: BiRebot102LeaderConfig):
def __init__(self, config: BiRebotArm102LeaderConfig):
super().__init__(config)
self.config = config
@@ -77,6 +76,27 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_action(self) -> RobotAction:
action_dict = {}
@@ -86,3 +106,8 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
@dataclass
class BiRebot102LeaderConfig(TeleoperatorConfig):
class BiRebotArm102LeaderConfig(TeleoperatorConfig):
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
left_arm_config: RebotArm102LeaderConfig
@@ -17,9 +17,7 @@
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..so_leader import SOLeader, SOLeaderTeleopConfig
from ..teleoperator import Teleoperator
@@ -28,7 +26,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
logger = logging.getLogger(__name__)
class BiSOLeader(BimanualMixin, Teleoperator):
class BiSOLeader(Teleoperator):
"""
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
@@ -69,12 +67,33 @@ class BiSOLeader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
def get_action(self) -> dict[str, float]:
action_dict = {}
# Add "left_" prefix
@@ -90,3 +109,8 @@ class BiSOLeader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
from .config_openarm_mini import OpenArmMiniConfig
from .openarm_mini import OpenArmMini
__all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
@@ -19,21 +19,12 @@ from dataclasses import dataclass
from ..config import TeleoperatorConfig
@dataclass
class OpenArmMiniConfigBase:
"""Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
# Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
port: str
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
# during readout. If `None`, no flipping is applied.
side: str | None = None
use_degrees: bool = True
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
pass
class OpenArmMiniConfig(TeleoperatorConfig):
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
port_right: str = "/dev/ttyUSB0"
port_left: str = "/dev/ttyUSB1"
use_degrees: bool = True
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__)
# Per-side motor direction flips applied during readout.
SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
"left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
}
# Motors whose direction is inverted during readout
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
# Leader joint 6 follower joint 7 (symmetric — its own inverse).
# Leader joint 6 maps to follower joint 7 and vice versa
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
GRIPPER_TELEOP_TO_DEGREES = -0.65
class OpenArmMini(Teleoperator):
"""OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
"""
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
"""
config_class = OpenArmMiniConfig
@@ -56,12 +56,9 @@ class OpenArmMini(Teleoperator):
super().__init__(config)
self.config = config
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
norm_mode_body = MotorNormMode.DEGREES
motors = {
motors_right = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
@@ -72,15 +69,46 @@ class OpenArmMini(Teleoperator):
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
self.bus = FeetechMotorsBus(
port=self.config.port,
motors=motors,
calibration=self.calibration,
motors_left = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
)
@property
def action_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property
def feedback_features(self) -> dict[str, type]:
@@ -88,12 +116,14 @@ class OpenArmMini(Teleoperator):
@property
def is_connected(self) -> bool:
return self.bus.is_connected
return self.bus_right.is_connected and self.bus_left.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting arm on {self.config.port}...")
self.bus.connect()
logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
if calibrate:
self.calibrate()
@@ -103,14 +133,14 @@ class OpenArmMini(Teleoperator):
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for a single OpenArm Mini arm.
Run calibration procedure for OpenArm Mini.
1. Disable torque
2. Ask user to position arm in hanging position with gripper closed
2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions)
5. Save calibration
@@ -122,51 +152,70 @@ class OpenArmMini(Teleoperator):
)
if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}")
self.bus.write_calibration(self.calibration)
cal_right = {
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return
logger.info(f"\nRunning calibration for {self}")
self.bus.disable_torque()
self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
logger.info("Setting Phase to 12 for all motors...")
for motor in self.bus.motors:
self.bus.write("Phase", motor, 12)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
"""Calibrate a single arm with Feetech motors."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input(
"\nCalibration: Zero Position\n"
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
homing_offsets = self.bus.set_half_turn_homings()
logger.info("Arm zero position set.")
homing_offsets = bus.set_half_turn_homings()
logger.info(f"{arm_name.capitalize()} arm zero position set.")
print("\nSetting motor ranges\n")
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
if self.calibration is None:
self.calibration = {}
motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
max_res = motor_resolution - 1
for motor_name, motor in self.bus.motors.items():
for motor_name, motor in bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
if motor_name == "gripper":
input(
"\nGripper Calibration\n"
"Step 1: CLOSE the gripper fully\n"
"Press ENTER when gripper is closed..."
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
f"Step 1: CLOSE the gripper fully\n"
f"Press ENTER when gripper is closed..."
)
closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
open_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos:
@@ -179,16 +228,16 @@ class OpenArmMini(Teleoperator):
drive_mode = 1
logger.info(
f" {motor_name}: range set to [{range_min}, {range_max}] "
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})"
)
else:
range_min = 0
range_max = max_res
drive_mode = 0
logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[motor_name] = MotorCalibration(
self.calibration[prefixed_name] = MotorCalibration(
id=motor.id,
drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name],
@@ -196,68 +245,108 @@ class OpenArmMini(Teleoperator):
range_max=range_max,
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
cal_for_bus = {
k.replace(f"{arm_name}_", ""): v
for k, v in self.calibration.items()
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
def configure(self) -> None:
self.bus.disable_torque()
self.bus.configure_motors()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_right.disable_torque()
self.bus_right.configure_motors()
for motor in self.bus_right.motors:
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
print("\nSetting up RIGHT arm motors...")
for motor in reversed(self.bus_right.motors):
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
self.bus_right.setup_motor(motor)
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> RobotAction:
"""Get current action (read positions from all motors)."""
"""Get current action from both arms (read positions from all motors)."""
start = time.perf_counter()
positions = self.bus.sync_read("Present_Position")
right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
# Per-side direction flip is applied based on the configured `side`.
action: dict[str, Any] = {}
for motor, val in positions.items():
for motor, val in right_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def enable_torque(self) -> None:
self.bus.enable_torque()
"""Enable torque on both arms for position control."""
self.bus_right.enable_torque()
self.bus_left.enable_torque()
def disable_torque(self) -> None:
self.bus.disable_torque()
"""Disable torque on both arms for free movement."""
self.bus_right.disable_torque()
self.bus_left.disable_torque()
def write_goal_positions(self, positions: dict[str, float]) -> None:
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
goals: dict[str, float] = {}
right_goals: dict[str, float] = {}
left_goals: dict[str, float] = {}
for key, val in positions.items():
if not key.endswith(".pos"):
continue
base = key.removesuffix(".pos")
# JOINT_REMAP is symmetric (its own inverse).
target = JOINT_REMAP.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
goals[target] = -val if target in self._motors_to_flip else val
motor_name = key.removesuffix(".pos")
if motor_name.startswith("right_"):
base = motor_name.removeprefix("right_")
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
elif motor_name.startswith("left_"):
base = motor_name.removeprefix("left_")
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
if goals:
self.bus.sync_write("Goal_Position", goals)
if right_goals:
self.bus_right.sync_write("Goal_Position", right_goals)
if left_goals:
self.bus_left.sync_write("Goal_Position", left_goals)
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
@@ -265,5 +354,6 @@ class OpenArmMini(Teleoperator):
@check_if_not_connected
def disconnect(self) -> None:
self.bus.disconnect()
self.bus_right.disconnect()
self.bus_left.disconnect()
logger.info(f"{self} disconnected.")
+2 -6
View File
@@ -99,18 +99,14 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .openarm_mini import OpenArmMini
return OpenArmMini(config)
elif config.type == "bi_openarm_mini":
from .bi_openarm_mini import BiOpenArmMini
return BiOpenArmMini(config)
elif config.type == "rebot_102_leader":
from .rebot_102_leader import RebotArm102Leader
return RebotArm102Leader(config)
elif config.type == "bi_rebot_102_leader":
from .bi_rebot_102_leader import BiRebot102Leader
from .bi_rebot_102_leader import BiRebotArm102Leader
return BiRebot102Leader(config)
return BiRebotArm102Leader(config)
else:
try:
return cast("Teleoperator", make_device_from_device_class(config))
-63
View File
@@ -1,63 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
class BimanualMixin:
"""Lifecycle delegation for bimanual robots and teleoperators.
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
their own ``__init__``. They retain ownership of feature dicts and the
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
``send_feedback``), which vary per-embodiment.
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
take precedence in the MRO::
class BiFooFollower(BimanualMixin, Robot): ...
"""
left_arm: Any
right_arm: Any
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
+8 -2
View File
@@ -216,9 +216,15 @@ def register_third_party_plugins() -> None:
This function uses `importlib.metadata` to find packages installed in the environment
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
"""
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
prefixes = (
"lerobot_robot_",
"lerobot_camera_",
"lerobot_teleoperator_",
"lerobot_policy_",
"lerobot_env_",
)
imported: list[str] = []
failed: list[str] = []
-73
View File
@@ -28,7 +28,6 @@ import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import pyarrow.parquet as pq # noqa: E402
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
@@ -345,78 +344,6 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
assert len(dataset) == 24
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
from lerobot.datasets.io_utils import write_tasks
from lerobot.utils.io_utils import write_json
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
for ep, length in enumerate(episode_lengths):
episode_index += [ep] * length
frame_index += list(range(length))
timestamp += [round(i / fps, 6) for i in range(length)]
task_index += [0] * length
subtask_index += [0] * length # legacy column the writer must drop
pd.DataFrame(
{
"episode_index": episode_index,
"frame_index": frame_index,
"timestamp": timestamp,
"task_index": task_index,
"subtask_index": subtask_index,
}
).to_parquet(data_dir / "file-000.parquet", index=False)
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
write_tasks(tasks_df, root)
write_json(
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
root / "meta" / "info.json",
)
return root
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
"""Rewriting a packed shard must keep one row group per episode, not collapse
every episode into a single giant row group."""
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
shard = root / "data" / "chunk-000" / "file-000.parquet"
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
staging_dir = tmp_path / "stage"
for ep in range(len(episode_lengths)):
_stage_episode(
staging_dir,
ep,
plan=[
{
"role": "assistant",
"content": f"subtask for ep {ep}",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
}
],
)
records = list(iter_episodes(root))
LanguageColumnsWriter().write_all(records, staging_dir, root)
# One row group per episode, with row counts matching the episode lengths.
md = pq.ParquetFile(shard).metadata
assert md.num_row_groups == len(episode_lengths)
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
# Language columns are still present after the per-episode rewrite.
table = pq.read_table(shard)
assert "language_persistent" in table.column_names
assert "language_events" in table.column_names
def test_speech_atom_shape_matches_plan_spec() -> None:
atom = speech_atom(2.5, "I'm cleaning up!")
assert atom["role"] == "assistant"
-55
View File
@@ -32,26 +32,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID
def assert_data_shards_one_row_group_per_episode(root):
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
import pyarrow.parquet as pq
shards = sorted((root / "data").rglob("*.parquet"))
assert shards, f"no data shards found under {root}/data"
n_episodes = 0
for shard in shards:
pf = pq.ParquetFile(shard)
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
assert pf.metadata.num_row_groups == len(set(episodes)), shard
for i in range(pf.metadata.num_row_groups):
rg_episodes = set(
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
)
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
n_episodes += len(set(episodes))
return n_episodes
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
"""Test that total number of episodes and frames are correctly aggregated."""
assert aggr_ds.num_episodes == expected_episodes, (
@@ -586,41 +566,6 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
)
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
Covers both the non-image (``df.to_parquet``) and image
(``to_parquet_with_hf_images``) write branches, including the merge-into-
existing-file branch via a low file-size threshold that forces packing.
"""
ds_0 = lerobot_dataset_factory(
root=tmp_path / "rg_0",
repo_id=f"{DUMMY_REPO_ID}_rg_0",
total_episodes=3,
total_frames=60,
use_videos=use_videos,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "rg_1",
repo_id=f"{DUMMY_REPO_ID}_rg_1",
total_episodes=4,
total_frames=80,
use_videos=use_videos,
)
aggr_root = tmp_path / "rg_aggr"
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
aggr_root=aggr_root,
)
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
+3 -21
View File
@@ -2370,32 +2370,14 @@ def test_aggregate_images_when_use_videos_false():
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=False, # images kept, stored as "image" dtype
use_videos=False, # expect "image" dtype
patterns=None,
)
key = f"{OBS_IMAGES}.back"
key_front = f"{OBS_IMAGES}.front"
assert key in out
assert key_front in out
assert out[key]["dtype"] == "image"
assert out[key_front]["dtype"] == "image"
assert out[key]["shape"] == initial["back"]
def test_aggregate_images_excluded():
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
initial = {"back": (480, 640, 3)}
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
exclude_images=True,
patterns=None,
)
assert f"{OBS_IMAGES}.back" not in out
assert f"{OBS_IMAGES}.front" not in out
assert key not in out
assert key_front not in out
def test_aggregate_images_when_use_videos_true():
+3 -3
View File
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
import pytest
from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
from lerobot.teleoperators.rebot_102_leader import (
RebotArm102Leader,
RebotArm102LeaderConfig,
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebot102LeaderConfig(
cfg = BiRebotArm102LeaderConfig(
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
)
teleop = BiRebot102Leader(cfg)
teleop = BiRebotArm102Leader(cfg)
assert any(k.startswith("left_") for k in teleop.action_features)
assert any(k.startswith("right_") for k in teleop.action_features)
assert "left_gripper.pos" in teleop.action_features
Generated
+900 -949
View File
File diff suppressed because it is too large Load Diff