Compare commits

...

4 Commits

Author SHA1 Message Date
Khalil Meftah 5a2c0369a2 fix(robots): restore joint clipping and wrist_yaw fallback in ReBot B601 send_action 2026-06-11 19:58:55 +02:00
Khalil Meftah 96445c52fc feat(config): add MIT control mode for ReBot
- Add configurable arm control mode (mit default, pos_vel fallback) with tunable mit_kp / mit_kd
- Add optional gripper control mode (force_pos default, mit optional) with gripper_mit_kp / gripper_mit_kd
- Update tests for MIT arm routing, gripper mode routing, and revised joint limits
2026-06-11 19:08:33 +02:00
Khalil Meftah c5cfd29275 fix(config): update joint limits for RebotB601Follower and RebotArm102Leader 2026-06-11 18:53:49 +02:00
Pepijn 41166b39fb fix(train): synchronize EpisodeAwareSampler shuffling across ranks and gate dataset download per node (#3768)
* fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync

In distributed training, accelerate can only synchronize the shuffle
permutation across ranks when the sampler exposes a generator attribute.
EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch
shards relied on every rank's global CPU RNG staying in lockstep forever;
any rank-asymmetric RNG consumption (e.g. eval rollouts on the main
process only) silently desynced the permutations, so ranks trained on
overlapping or missing samples.

* fix(train): seed sampler generator and gate dataset download per node

- Pass a generator seeded with cfg.seed to EpisodeAwareSampler so
  accelerator.prepare registers it as the synchronized RNG and the
  shuffle order is reproducible.
- Gate the initial make_dataset call on is_local_main_process instead of
  is_main_process: the global main process only exists on node 0, so on
  every other node all local ranks were downloading the dataset and
  building the Arrow cache concurrently.
2026-06-11 11:07:42 +02:00
8 changed files with 138 additions and 28 deletions
+7 -1
View File
@@ -30,6 +30,7 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
):
"""Sampler that optionally incorporates episode boundary information.
@@ -41,6 +42,10 @@ class EpisodeAwareSampler:
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
@@ -73,10 +78,11 @@ class EpisodeAwareSampler:
self.indices = indices
self.shuffle = shuffle
self.generator = generator
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices)):
for i in torch.randperm(len(self.indices), generator=self.generator):
yield self.indices[i]
else:
for i in self.indices:
@@ -52,7 +52,13 @@ class BiRebotB601Follower(Robot):
cameras=config.left_arm_config.cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
control_mode=config.left_arm_config.control_mode,
mit_kp=config.left_arm_config.mit_kp,
mit_kd=config.left_arm_config.mit_kd,
gripper_control_mode=config.left_arm_config.gripper_control_mode,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.left_arm_config.gripper_mit_kp,
gripper_mit_kd=config.left_arm_config.gripper_mit_kd,
joint_limits=config.left_arm_config.joint_limits,
)
@@ -67,7 +73,13 @@ class BiRebotB601Follower(Robot):
cameras=config.right_arm_config.cameras,
motor_can_ids=config.right_arm_config.motor_can_ids,
pos_vel_velocity=config.right_arm_config.pos_vel_velocity,
control_mode=config.right_arm_config.control_mode,
mit_kp=config.right_arm_config.mit_kp,
mit_kd=config.right_arm_config.mit_kd,
gripper_control_mode=config.right_arm_config.gripper_control_mode,
gripper_torque_ratio=config.right_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.right_arm_config.gripper_mit_kp,
gripper_mit_kd=config.right_arm_config.gripper_mit_kd,
joint_limits=config.right_arm_config.joint_limits,
)
@@ -65,18 +65,33 @@ class RebotB601FollowerConfig:
}
)
# Target velocity for joints running in POS_VEL mode, in degrees/s. A scalar is
# applied to every joint; a list provides one value per joint (in motor order).
pos_vel_velocity: float | list[float] = field(default_factory=lambda: [150.0] * 7)
# Max speed (deg/s) per joint for POS_VEL arms and FORCE_POS gripper (motor order).
pos_vel_velocity: float | list[float] = field(
default_factory=lambda: [150.0, 150.0, 150.0, 150.0, 150.0, 150.0, 500.0]
)
# Torque/current ratio for the gripper's FORCE_POS mode, in range [0, 1].
gripper_torque_ratio: float = 0.1
# Arm control: "mit" or "pos_vel".
control_mode: str = "mit"
# MIT kp/kd per arm joint (motor order). Unused when control_mode="pos_vel".
mit_kp: float | list[float] = field(default_factory=lambda: [45.0, 45.0, 45.0, 8.0, 9.0, 8.0, 8.0])
mit_kd: float | list[float] = field(default_factory=lambda: [12.0, 12.0, 12.0, 1.0, 1.0, 1.0, 1.0])
# Gripper control: "force_pos" or "mit".
gripper_control_mode: str = "force_pos"
# FORCE_POS only: max grip force, in [0, 1].
gripper_torque_ratio: float = 0.05
# MIT only.
gripper_mit_kp: float = 8.0
gripper_mit_kd: float = 0.3
# Soft joint limits (degrees). Every action is clipped to these limits.
joint_limits: dict[str, tuple[float, float]] = field(
default_factory=lambda: {
"shoulder_pan": (-145.0, 145.0),
"shoulder_lift": (-170.0, 1.0),
"shoulder_pan": (-150.0, 150.0),
"shoulder_lift": (-200.0, 1.0),
"elbow_flex": (-200.0, 1.0),
"wrist_flex": (-80.0, 90.0),
"wrist_yaw": (-90.0, 90.0),
@@ -169,11 +169,25 @@ class RebotB601Follower(Robot):
print(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
if self.config.control_mode not in ("pos_vel", "mit"):
raise ValueError(
f"Unsupported control_mode '{self.config.control_mode}'. Use 'pos_vel' or 'mit'."
)
if self.config.gripper_control_mode not in ("force_pos", "mit"):
raise ValueError(
f"Unsupported gripper_control_mode '{self.config.gripper_control_mode}'. "
"Use 'force_pos' or 'mit'."
)
use_mit = self.config.control_mode == "mit"
gripper_use_mit = self.config.gripper_control_mode == "mit"
self.bus.enable_all()
for motor_name, motor in self.motors.items():
target_mode = (
MotorBridgeMode.FORCE_POS if motor_name == GRIPPER_MOTOR else MotorBridgeMode.POS_VEL
)
if motor_name == GRIPPER_MOTOR:
target_mode = MotorBridgeMode.MIT if gripper_use_mit else MotorBridgeMode.FORCE_POS
elif use_mit:
target_mode = MotorBridgeMode.MIT
else:
target_mode = MotorBridgeMode.POS_VEL
for attempt in range(_ENSURE_MODE_RETRIES + 1):
try:
motor.ensure_mode(target_mode)
@@ -252,22 +266,34 @@ class RebotB601Follower(Robot):
goal_present_pos = {key: (g, present_pos.get(key, g)) for key, g in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
use_mit = self.config.control_mode == "mit"
for motor_name, position_deg in goal_pos.items():
motor = self.motors.get(motor_name)
if motor is None:
continue
idx = self.motor_names.index(motor_name)
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
pos_rad = math.radians(position_deg)
vel_rad = math.radians(vel_deg_s)
if motor_name == GRIPPER_MOTOR:
motor.send_force_pos(pos_rad, vel_rad, self.config.gripper_torque_ratio)
if self.config.gripper_control_mode == "mit":
motor.send_mit(pos_rad, 0.0, self.config.gripper_mit_kp, self.config.gripper_mit_kd, 0.0)
else:
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_force_pos(pos_rad, math.radians(vel_deg_s), self.config.gripper_torque_ratio)
elif use_mit:
kp = self.config.mit_kp[idx] if isinstance(self.config.mit_kp, list) else self.config.mit_kp
kd = self.config.mit_kd[idx] if isinstance(self.config.mit_kd, list) else self.config.mit_kd
motor.send_mit(pos_rad, 0.0, kp, kd, 0.0)
else:
motor.send_pos_vel(pos_rad, vel_rad)
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_pos_vel(pos_rad, math.radians(vel_deg_s))
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
+15 -5
View File
@@ -232,15 +232,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process:
logging.info("Creating dataset")
# Dataset loading synchronization: each node's local main process downloads first to avoid
# race conditions (the global main process only exists on node 0, so gating on it would let
# all ranks of the other nodes download and build the Arrow cache concurrently).
if accelerator.is_local_main_process:
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset
if not is_main_process:
# Now all other processes can safely load the dataset from the local cache
if not accelerator.is_local_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -386,12 +389,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
generator=sampler_generator,
)
else:
shuffle = True
@@ -65,7 +65,7 @@ class RebotArm102LeaderConfig:
joint_ranges: dict[str, list[int]] = field(
default_factory=lambda: {
"shoulder_pan": [-150, 150],
"shoulder_lift": [-170, 1],
"shoulder_lift": [-200, 1],
"elbow_flex": [-200, 1],
"wrist_flex": [-80, 90],
"wrist_yaw": [-90, 90],
+24
View File
@@ -114,6 +114,30 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
assert list(sampler_a) == list(sampler_b)
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == order_before
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
+20 -3
View File
@@ -91,10 +91,11 @@ def test_get_observation_converts_to_degrees(follower):
def test_send_action_clips_to_joint_limits(follower):
# shoulder_pan limit is (-145, 145); request beyond the upper bound.
# shoulder_pan limit is (-150, 150); request beyond the upper bound.
returned = follower.send_action({"shoulder_pan.pos": 999.0})
assert returned["shoulder_pan.pos"] == 145.0
follower.motors["shoulder_pan"].send_pos_vel.assert_called_once()
assert returned["shoulder_pan.pos"] == 150.0
# Default control_mode is "mit", so arm joints are driven via send_mit.
follower.motors["shoulder_pan"].send_mit.assert_called_once()
def test_send_action_routes_gripper_to_force_pos(follower):
@@ -103,6 +104,22 @@ def test_send_action_routes_gripper_to_force_pos(follower):
follower.motors["gripper"].send_pos_vel.assert_not_called()
def test_gripper_mit_mode_routes_to_send_mit():
bus_mock = _make_bus_mock()
with (
patch(f"{_MODULE}.require_package", lambda *a, **kw: None),
patch(f"{_MODULE}.MotorBridgeController") as controller_cls,
patch(f"{_MODULE}.MotorBridgeMode", MagicMock()),
):
controller_cls.from_dm_serial.return_value = bus_mock
cfg = RebotB601FollowerRobotConfig(port="/dev/null", gripper_control_mode="mit")
robot = RebotB601Follower(cfg)
robot.connect(calibrate=False)
robot.send_action({"gripper.pos": -10.0})
robot.motors["gripper"].send_mit.assert_called_once()
robot.motors["gripper"].send_force_pos.assert_not_called()
def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebotB601FollowerConfig(