Compare commits

..

3 Commits

Author SHA1 Message Date
github-actions[bot] 9b4516ef63 chore(dependencies): update uv.lock 2026-07-01 05:04:59 +00:00
Caroline Pascal 8414188db0 fix(datasets dependency): removing datasets dependency in pretrained.py (#3897) 2026-06-30 20:21:06 +02:00
Khalil Meftah 0da98afd63 Feat(robot): add MIT control mode to ReBot (#3778)
* fix(config): update joint limits for RebotB601Follower and RebotArm102Leader

* feat(config): add MIT control mode ReBot

- Add configurable arm control mode (mit default, pos_vel fallback) with tunable mit_kp / mit_kd
- Add optional gripper control mode (force_pos default, mit optional) with gripper_mit_kp / gripper_mit_kd
- Update tests for MIT arm routing, gripper mode routing, and revised joint limits

* fix(robots): restore joint clipping and wrist_yaw fallback in ReBot B601 send_action

* feat(robot): increase gripper velocity and torque for rebot arm
2026-06-30 17:17:50 +02:00
10 changed files with 873 additions and 892 deletions
+1 -9
View File
@@ -36,9 +36,7 @@ HW_VIDEO_CODECS = [
"h264_vaapi", # Linux Intel/AMD
"h264_qsv", # Intel Quick Sync
]
VALID_VIDEO_CODECS: frozenset[str] = frozenset(
{"h264", "hevc", "libsvtav1", "libaom-av1", "auto", *HW_VIDEO_CODECS}
)
VALID_VIDEO_CODECS: frozenset[str] = frozenset({"h264", "hevc", "libsvtav1", "auto", *HW_VIDEO_CODECS})
# Aliases for legacy video codec names.
VIDEO_CODECS_ALIASES: dict[str, str] = {"av1": "libsvtav1"}
@@ -222,12 +220,6 @@ class VideoEncoderConfig:
if self.fast_decode:
opts["tune"] = "fastdecode"
set_if("threads", encoder_threads)
elif self.vcodec == "libaom-av1":
set_if("crf", self.crf)
set_if("preset", self.preset)
if encoder_threads is not None:
opts["threads"] = encoder_threads
opts["row-mt"] = 1
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
if self.crf is not None:
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
+27 -29
View File
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import builtins
import dataclasses
@@ -19,7 +21,7 @@ import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TypedDict, TypeVar, Unpack
from typing import TYPE_CHECKING, TypedDict, TypeVar, Unpack
import packaging
import safetensors
@@ -38,10 +40,13 @@ from .utils import log_model_loading_keys
T = TypeVar("T", bound="PreTrainedPolicy")
if TYPE_CHECKING:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
def _build_card_context(
cfg: TrainPipelineConfig | None,
dataset_repo_id: str | None,
dataset_meta: LeRobotDatasetMetadata | None,
input_features: dict | None,
output_features: dict | None,
) -> dict:
@@ -72,30 +77,16 @@ def _build_card_context(
"lerobot_version": __version__,
}
if dataset_repo_id:
dataset_cfg = getattr(cfg, "dataset", None)
try:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(
dataset_repo_id,
root=getattr(dataset_cfg, "root", None),
revision=getattr(dataset_cfg, "revision", None),
)
context["dataset"] = {
"repo_id": dataset_repo_id,
"episodes": meta.total_episodes,
"frames": meta.total_frames,
"fps": meta.fps,
"tasks": [str(task) for task in meta.tasks.index],
}
context["robot_type"] = meta.robot_type
context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys]
except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push
logging.warning(
f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be "
f"omitted from the model card. ({e})"
)
if dataset_meta is not None:
context["dataset"] = {
"repo_id": dataset_meta.repo_id,
"episodes": dataset_meta.total_episodes,
"frames": dataset_meta.total_frames,
"fps": dataset_meta.fps,
"tasks": [str(task) for task in dataset_meta.tasks.index],
}
context["robot_type"] = dataset_meta.robot_type
context["cameras"] = [key.split(".")[-1] for key in dataset_meta.camera_keys]
return context
@@ -304,6 +295,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
cfg: TrainPipelineConfig,
peft_model=None,
state_dict: dict[str, Tensor] | None = None,
dataset_meta: LeRobotDatasetMetadata | None = None,
):
api = HfApi()
repo_id = api.create_repo(
@@ -325,7 +317,12 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self.save_pretrained(saved_path, state_dict=state_dict)
card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
cfg.dataset.repo_id,
self.config.type,
self.config.license,
self.config.tags,
cfg=cfg,
dataset_meta=dataset_meta,
)
card.save(str(saved_path / "README.md"))
@@ -352,6 +349,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
license: str | None,
tags: list[str] | None,
cfg: TrainPipelineConfig | None = None,
dataset_meta: LeRobotDatasetMetadata | None = None,
) -> ModelCard:
base_model_mapping = {
"smolvla": "lerobot/smolvla_base",
@@ -372,7 +370,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
)
context = _build_card_context(
cfg, dataset_repo_id, self.config.input_features, self.config.output_features
cfg, dataset_meta, self.config.input_features, self.config.output_features
)
# Used by the template to pre-fill commands and the "Fine-tuned from" line.
context["policy_repo_id"] = getattr(self.config, "repo_id", None)
@@ -389,7 +387,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self,
peft_config=None,
peft_cli_overrides: dict | None = None,
) -> "PreTrainedPolicy":
) -> PreTrainedPolicy:
"""
Wrap this policy with PEFT adapters for parameter-efficient fine-tuning.
@@ -65,7 +65,13 @@ class BiRebotB601Follower(BimanualMixin, Robot):
cameras=left_arm_cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
control_mode=config.left_arm_config.control_mode,
mit_kp=config.left_arm_config.mit_kp,
mit_kd=config.left_arm_config.mit_kd,
gripper_control_mode=config.left_arm_config.gripper_control_mode,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.left_arm_config.gripper_mit_kp,
gripper_mit_kd=config.left_arm_config.gripper_mit_kd,
joint_limits=config.left_arm_config.joint_limits,
)
@@ -80,7 +86,13 @@ class BiRebotB601Follower(BimanualMixin, Robot):
cameras=config.right_arm_config.cameras,
motor_can_ids=config.right_arm_config.motor_can_ids,
pos_vel_velocity=config.right_arm_config.pos_vel_velocity,
control_mode=config.right_arm_config.control_mode,
mit_kp=config.right_arm_config.mit_kp,
mit_kd=config.right_arm_config.mit_kd,
gripper_control_mode=config.right_arm_config.gripper_control_mode,
gripper_torque_ratio=config.right_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.right_arm_config.gripper_mit_kp,
gripper_mit_kd=config.right_arm_config.gripper_mit_kd,
joint_limits=config.right_arm_config.joint_limits,
)
@@ -65,18 +65,33 @@ class RebotB601FollowerConfig:
}
)
# Target velocity for joints running in POS_VEL mode, in degrees/s. A scalar is
# applied to every joint; a list provides one value per joint (in motor order).
pos_vel_velocity: float | list[float] = field(default_factory=lambda: [150.0] * 7)
# Max speed (deg/s) per joint for POS_VEL arms and FORCE_POS gripper (motor order).
pos_vel_velocity: float | list[float] = field(
default_factory=lambda: [150.0, 150.0, 150.0, 150.0, 150.0, 150.0, 900.0]
)
# Torque/current ratio for the gripper's FORCE_POS mode, in range [0, 1].
gripper_torque_ratio: float = 0.1
# Arm control: "mit" or "pos_vel".
control_mode: str = "mit"
# MIT kp/kd per arm joint (motor order). Unused when control_mode="pos_vel".
mit_kp: float | list[float] = field(default_factory=lambda: [45.0, 45.0, 45.0, 8.0, 9.0, 8.0, 8.0])
mit_kd: float | list[float] = field(default_factory=lambda: [12.0, 12.0, 12.0, 1.0, 1.0, 1.0, 1.0])
# Gripper control: "force_pos" or "mit".
gripper_control_mode: str = "force_pos"
# FORCE_POS only: max grip force, in [0, 1].
gripper_torque_ratio: float = 0.07
# MIT only.
gripper_mit_kp: float = 8.0
gripper_mit_kd: float = 0.3
# Soft joint limits (degrees). These are clipped against on every action.
joint_limits: dict[str, tuple[float, float]] = field(
default_factory=lambda: {
"shoulder_pan": (-145.0, 145.0),
"shoulder_lift": (-170.0, 1.0),
"shoulder_pan": (-150.0, 150.0),
"shoulder_lift": (-200.0, 1.0),
"elbow_flex": (-200.0, 1.0),
"wrist_flex": (-80.0, 90.0),
"wrist_yaw": (-90.0, 90.0),
@@ -174,11 +174,25 @@ class RebotB601Follower(Robot):
print(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
if self.config.control_mode not in ("pos_vel", "mit"):
raise ValueError(
f"Unsupported control_mode '{self.config.control_mode}'. Use 'pos_vel' or 'mit'."
)
if self.config.gripper_control_mode not in ("force_pos", "mit"):
raise ValueError(
f"Unsupported gripper_control_mode '{self.config.gripper_control_mode}'. "
"Use 'force_pos' or 'mit'."
)
use_mit = self.config.control_mode == "mit"
gripper_use_mit = self.config.gripper_control_mode == "mit"
self.bus.enable_all()
for motor_name, motor in self.motors.items():
target_mode = (
MotorBridgeMode.FORCE_POS if motor_name == GRIPPER_MOTOR else MotorBridgeMode.POS_VEL
)
if motor_name == GRIPPER_MOTOR:
target_mode = MotorBridgeMode.MIT if gripper_use_mit else MotorBridgeMode.FORCE_POS
elif use_mit:
target_mode = MotorBridgeMode.MIT
else:
target_mode = MotorBridgeMode.POS_VEL
for attempt in range(_ENSURE_MODE_RETRIES + 1):
try:
motor.ensure_mode(target_mode)
@@ -264,22 +278,34 @@ class RebotB601Follower(Robot):
goal_present_pos = {key: (g, present_pos.get(key, g)) for key, g in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
use_mit = self.config.control_mode == "mit"
for motor_name, position_deg in goal_pos.items():
motor = self.motors.get(motor_name)
if motor is None:
continue
idx = self.motor_names.index(motor_name)
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
pos_rad = math.radians(position_deg)
vel_rad = math.radians(vel_deg_s)
if motor_name == GRIPPER_MOTOR:
motor.send_force_pos(pos_rad, vel_rad, self.config.gripper_torque_ratio)
if self.config.gripper_control_mode == "mit":
motor.send_mit(pos_rad, 0.0, self.config.gripper_mit_kp, self.config.gripper_mit_kd, 0.0)
else:
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_force_pos(pos_rad, math.radians(vel_deg_s), self.config.gripper_torque_ratio)
elif use_mit:
kp = self.config.mit_kp[idx] if isinstance(self.config.mit_kp, list) else self.config.mit_kp
kd = self.config.mit_kd[idx] if isinstance(self.config.mit_kd, list) else self.config.mit_kd
motor.send_mit(pos_rad, 0.0, kp, kd, 0.0)
else:
motor.send_pos_vel(pos_rad, vel_rad)
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_pos_vel(pos_rad, math.radians(vel_deg_s))
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
+2 -2
View File
@@ -736,9 +736,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
unwrapped_model = accelerator.unwrap_model(policy)
# PEFT only applies when training a policy — reward models use the plain path.
if not cfg.is_reward_model_training and cfg.policy.use_peft:
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model, dataset_meta=dataset.meta)
else:
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict, dataset_meta=dataset.meta)
preprocessor.push_to_hub(active_cfg.repo_id)
postprocessor.push_to_hub(active_cfg.repo_id)
@@ -65,7 +65,7 @@ class RebotArm102LeaderConfig:
joint_ranges: dict[str, list[int]] = field(
default_factory=lambda: {
"shoulder_pan": [-150, 150],
"shoulder_lift": [-170, 1],
"shoulder_lift": [-200, 1],
"elbow_flex": [-200, 1],
"wrist_flex": [-80, 90],
"wrist_yaw": [-90, 90],
+1 -2
View File
@@ -1531,7 +1531,6 @@ def test_valid_video_codecs_constant():
assert "h264" in VALID_VIDEO_CODECS
assert "hevc" in VALID_VIDEO_CODECS
assert "libsvtav1" in VALID_VIDEO_CODECS
assert "libaom-av1" in VALID_VIDEO_CODECS
assert "auto" in VALID_VIDEO_CODECS
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
assert "h264_nvenc" in VALID_VIDEO_CODECS
@@ -1539,7 +1538,7 @@ def test_valid_video_codecs_constant():
assert "h264_qsv" in VALID_VIDEO_CODECS
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
assert "hevc_nvenc" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 11
assert len(VALID_VIDEO_CODECS) == 10
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
+20 -3
View File
@@ -91,10 +91,11 @@ def test_get_observation_converts_to_degrees(follower):
def test_send_action_clips_to_joint_limits(follower):
# shoulder_pan limit is (-145, 145); request beyond the upper bound.
# shoulder_pan limit is (-150, 150); request beyond the upper bound.
returned = follower.send_action({"shoulder_pan.pos": 999.0})
assert returned["shoulder_pan.pos"] == 145.0
follower.motors["shoulder_pan"].send_pos_vel.assert_called_once()
assert returned["shoulder_pan.pos"] == 150.0
# Default control_mode is "mit", so arm joints are driven via send_mit.
follower.motors["shoulder_pan"].send_mit.assert_called_once()
def test_send_action_routes_gripper_to_force_pos(follower):
@@ -103,6 +104,22 @@ def test_send_action_routes_gripper_to_force_pos(follower):
follower.motors["gripper"].send_pos_vel.assert_not_called()
def test_gripper_mit_mode_routes_to_send_mit():
bus_mock = _make_bus_mock()
with (
patch(f"{_MODULE}.require_package", lambda *a, **kw: None),
patch(f"{_MODULE}.MotorBridgeController") as controller_cls,
patch(f"{_MODULE}.MotorBridgeMode", MagicMock()),
):
controller_cls.from_dm_serial.return_value = bus_mock
cfg = RebotB601FollowerRobotConfig(port="/dev/null", gripper_control_mode="mit")
robot = RebotB601Follower(cfg)
robot.connect(calibrate=False)
robot.send_action({"gripper.pos": -10.0})
robot.motors["gripper"].send_mit.assert_called_once()
robot.motors["gripper"].send_force_pos.assert_not_called()
def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebotB601FollowerConfig(
Generated
+750 -828
View File
File diff suppressed because it is too large Load Diff