Compare commits

..

2 Commits

Author SHA1 Message Date
Maximellerbach 811727d462 renaming to return_intermediate_predictions 2026-06-10 13:50:59 +02:00
Maxime Ellerbach d1a8910f60 feat(policy): adding return_extra to policy contracts 2026-06-10 11:23:30 +00:00
9 changed files with 45 additions and 140 deletions
+1 -7
View File
@@ -30,7 +30,6 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
):
"""Sampler that optionally incorporates episode boundary information.
@@ -42,10 +41,6 @@ class EpisodeAwareSampler:
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
@@ -78,11 +73,10 @@ class EpisodeAwareSampler:
self.indices = indices
self.shuffle = shuffle
self.generator = generator
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices), generator=self.generator):
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for i in self.indices:
+17 -2
View File
@@ -40,6 +40,7 @@ T = TypeVar("T", bound="PreTrainedPolicy")
class ActionSelectKwargs(TypedDict, total=False):
noise: Tensor | None
return_intermediate_predictions: bool
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
@@ -187,20 +188,34 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
def predict_action_chunk(
self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor | tuple[Tensor, dict[str, Tensor]]:
"""Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
Child classes using action chunking should use this method within `select_action` to form the action chunk
cached for selection.
By default returns just the action `Tensor`. If `return_intermediate_predictions=True`,
returns `(action, predictions)` where `predictions` is a (possibly empty) `dict[str, Tensor]`
of additional model predictions a policy may expose (e.g. world-model predicted frames).
Policies that produce nothing extra may ignore the kwarg.
"""
raise NotImplementedError
@abc.abstractmethod
def select_action(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
def select_action(
self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor | tuple[Tensor, dict[str, Tensor]]:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
By default returns just the action `Tensor`. If `return_intermediate_predictions=True`,
returns `(action, predictions)` where `predictions` is a (possibly empty) `dict[str, Tensor]`
of additional model predictions a policy may expose (e.g. world-model predicted frames).
Policies that produce nothing extra may ignore the kwarg.
"""
raise NotImplementedError
@@ -52,13 +52,7 @@ class BiRebotB601Follower(Robot):
cameras=config.left_arm_config.cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
control_mode=config.left_arm_config.control_mode,
mit_kp=config.left_arm_config.mit_kp,
mit_kd=config.left_arm_config.mit_kd,
gripper_control_mode=config.left_arm_config.gripper_control_mode,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.left_arm_config.gripper_mit_kp,
gripper_mit_kd=config.left_arm_config.gripper_mit_kd,
joint_limits=config.left_arm_config.joint_limits,
)
@@ -73,13 +67,7 @@ class BiRebotB601Follower(Robot):
cameras=config.right_arm_config.cameras,
motor_can_ids=config.right_arm_config.motor_can_ids,
pos_vel_velocity=config.right_arm_config.pos_vel_velocity,
control_mode=config.right_arm_config.control_mode,
mit_kp=config.right_arm_config.mit_kp,
mit_kd=config.right_arm_config.mit_kd,
gripper_control_mode=config.right_arm_config.gripper_control_mode,
gripper_torque_ratio=config.right_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.right_arm_config.gripper_mit_kp,
gripper_mit_kd=config.right_arm_config.gripper_mit_kd,
joint_limits=config.right_arm_config.joint_limits,
)
@@ -65,33 +65,18 @@ class RebotB601FollowerConfig:
}
)
# Max speed (deg/s) per joint for POS_VEL arms and FORCE_POS gripper (motor order).
pos_vel_velocity: float | list[float] = field(
default_factory=lambda: [150.0, 150.0, 150.0, 150.0, 150.0, 150.0, 500.0]
)
# Target velocity for joints running in POS_VEL mode, in degrees/s. A scalar is
# applied to every joint; a list provides one value per joint (in motor order).
pos_vel_velocity: float | list[float] = field(default_factory=lambda: [150.0] * 7)
# Arm control: "mit" or "pos_vel".
control_mode: str = "mit"
# MIT kp/kd per arm joint (motor order). Unused when control_mode="pos_vel".
mit_kp: float | list[float] = field(default_factory=lambda: [45.0, 45.0, 45.0, 8.0, 9.0, 8.0, 8.0])
mit_kd: float | list[float] = field(default_factory=lambda: [12.0, 12.0, 12.0, 1.0, 1.0, 1.0, 1.0])
# Gripper control: "force_pos" or "mit".
gripper_control_mode: str = "force_pos"
# FORCE_POS only: max grip force, in [0, 1].
gripper_torque_ratio: float = 0.05
# MIT only.
gripper_mit_kp: float = 8.0
gripper_mit_kd: float = 0.3
# Torque/current ratio for the gripper's FORCE_POS mode, in range [0, 1].
gripper_torque_ratio: float = 0.1
# Soft joint limits (degrees). These are clipped against on every action.
joint_limits: dict[str, tuple[float, float]] = field(
default_factory=lambda: {
"shoulder_pan": (-150.0, 150.0),
"shoulder_lift": (-200.0, 1.0),
"shoulder_pan": (-145.0, 145.0),
"shoulder_lift": (-170.0, 1.0),
"elbow_flex": (-200.0, 1.0),
"wrist_flex": (-80.0, 90.0),
"wrist_yaw": (-90.0, 90.0),
@@ -169,25 +169,11 @@ class RebotB601Follower(Robot):
print(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
if self.config.control_mode not in ("pos_vel", "mit"):
raise ValueError(
f"Unsupported control_mode '{self.config.control_mode}'. Use 'pos_vel' or 'mit'."
)
if self.config.gripper_control_mode not in ("force_pos", "mit"):
raise ValueError(
f"Unsupported gripper_control_mode '{self.config.gripper_control_mode}'. "
"Use 'force_pos' or 'mit'."
)
use_mit = self.config.control_mode == "mit"
gripper_use_mit = self.config.gripper_control_mode == "mit"
self.bus.enable_all()
for motor_name, motor in self.motors.items():
if motor_name == GRIPPER_MOTOR:
target_mode = MotorBridgeMode.MIT if gripper_use_mit else MotorBridgeMode.FORCE_POS
elif use_mit:
target_mode = MotorBridgeMode.MIT
else:
target_mode = MotorBridgeMode.POS_VEL
target_mode = (
MotorBridgeMode.FORCE_POS if motor_name == GRIPPER_MOTOR else MotorBridgeMode.POS_VEL
)
for attempt in range(_ENSURE_MODE_RETRIES + 1):
try:
motor.ensure_mode(target_mode)
@@ -266,34 +252,22 @@ class RebotB601Follower(Robot):
goal_present_pos = {key: (g, present_pos.get(key, g)) for key, g in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
use_mit = self.config.control_mode == "mit"
for motor_name, position_deg in goal_pos.items():
motor = self.motors.get(motor_name)
if motor is None:
continue
idx = self.motor_names.index(motor_name)
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
pos_rad = math.radians(position_deg)
vel_rad = math.radians(vel_deg_s)
if motor_name == GRIPPER_MOTOR:
if self.config.gripper_control_mode == "mit":
motor.send_mit(pos_rad, 0.0, self.config.gripper_mit_kp, self.config.gripper_mit_kd, 0.0)
else:
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_force_pos(pos_rad, math.radians(vel_deg_s), self.config.gripper_torque_ratio)
elif use_mit:
kp = self.config.mit_kp[idx] if isinstance(self.config.mit_kp, list) else self.config.mit_kp
kd = self.config.mit_kd[idx] if isinstance(self.config.mit_kd, list) else self.config.mit_kd
motor.send_mit(pos_rad, 0.0, kp, kd, 0.0)
motor.send_force_pos(pos_rad, vel_rad, self.config.gripper_torque_ratio)
else:
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_pos_vel(pos_rad, math.radians(vel_deg_s))
motor.send_pos_vel(pos_rad, vel_rad)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
+5 -15
View File
@@ -232,18 +232,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: each node's local main process downloads first to avoid
# race conditions (the global main process only exists on node 0, so gating on it would let
# all ranks of the other nodes download and build the Arrow cache concurrently).
if accelerator.is_local_main_process:
if is_main_process:
logging.info("Creating dataset")
# Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset from the local cache
if not accelerator.is_local_main_process:
# Now all other processes can safely load the dataset
if not is_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -389,19 +386,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
generator=sampler_generator,
)
else:
shuffle = True
@@ -65,7 +65,7 @@ class RebotArm102LeaderConfig:
joint_ranges: dict[str, list[int]] = field(
default_factory=lambda: {
"shoulder_pan": [-150, 150],
"shoulder_lift": [-200, 1],
"shoulder_lift": [-170, 1],
"elbow_flex": [-200, 1],
"wrist_flex": [-80, 90],
"wrist_yaw": [-90, 90],
-24
View File
@@ -114,30 +114,6 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
assert list(sampler_a) == list(sampler_b)
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == order_before
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
+3 -20
View File
@@ -91,11 +91,10 @@ def test_get_observation_converts_to_degrees(follower):
def test_send_action_clips_to_joint_limits(follower):
# shoulder_pan limit is (-150, 150); request beyond the upper bound.
# shoulder_pan limit is (-145, 145); request beyond the upper bound.
returned = follower.send_action({"shoulder_pan.pos": 999.0})
assert returned["shoulder_pan.pos"] == 150.0
# Default control_mode is "mit", so arm joints are driven via send_mit.
follower.motors["shoulder_pan"].send_mit.assert_called_once()
assert returned["shoulder_pan.pos"] == 145.0
follower.motors["shoulder_pan"].send_pos_vel.assert_called_once()
def test_send_action_routes_gripper_to_force_pos(follower):
@@ -104,22 +103,6 @@ def test_send_action_routes_gripper_to_force_pos(follower):
follower.motors["gripper"].send_pos_vel.assert_not_called()
def test_gripper_mit_mode_routes_to_send_mit():
bus_mock = _make_bus_mock()
with (
patch(f"{_MODULE}.require_package", lambda *a, **kw: None),
patch(f"{_MODULE}.MotorBridgeController") as controller_cls,
patch(f"{_MODULE}.MotorBridgeMode", MagicMock()),
):
controller_cls.from_dm_serial.return_value = bus_mock
cfg = RebotB601FollowerRobotConfig(port="/dev/null", gripper_control_mode="mit")
robot = RebotB601Follower(cfg)
robot.connect(calibrate=False)
robot.send_action({"gripper.pos": -10.0})
robot.motors["gripper"].send_mit.assert_called_once()
robot.motors["gripper"].send_force_pos.assert_not_called()
def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebotB601FollowerConfig(