diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index b8ec83bc3..cbfe647f7 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -73,6 +73,7 @@ class GrootPolicy(PreTrainedPolicy): # Initialize GR00T model using ported components self._groot_model = self._create_groot_model() self._action_queue_steps = self._resolve_action_queue_steps() + self._warned_native_relative_rtc_prefix_disabled = False self.reset() @@ -306,6 +307,17 @@ class GrootPolicy(PreTrainedPolicy): ) -> tuple[dict[str, Tensor], dict[str, object] | None]: if prev_chunk_left_over is None: return inputs, None + if getattr(self.config, "use_relative_actions", False): + # Generic RTC only provides normalized leftovers from the previous chunk. For + # native relative-action N1.7 checkpoints those rows are tied to the old + # observation state and old per-horizon stats row, so using them as the next + # prefix can push the policy in the wrong direction. Run without native RTC + # overlap guidance until a GROOT-specific RTC path can pass re-anchored + # absolute leftovers through. + if not getattr(self, "_warned_native_relative_rtc_prefix_disabled", False): + logger.info("Disabling native GR00T RTC prefix for relative-action policy") + self._warned_native_relative_rtc_prefix_disabled = True + return inputs, None if not isinstance(prev_chunk_left_over, torch.Tensor): raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.") if prev_chunk_left_over.numel() == 0: diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index c0208dbe5..32438333c 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -589,6 +589,9 @@ def _slice_stats_entry(stats: dict[str, Any], indices: list[int]) -> dict[str, A max_index = max(indices) sliced: dict[str, Any] = {} for stat_name, value in stats.items(): + if stat_name == "count": + sliced[stat_name] = torch.as_tensor(value).flatten().tolist() + continue tensor = torch.as_tensor(value, dtype=torch.float32) if tensor.ndim >= 2: if tensor.shape[-1] <= max_index: @@ -1531,13 +1534,34 @@ class GrootN17PackInputsStep(ProcessorStep): ) action = torch.cat([action, pad], dim=1) horizon = self.action_horizon - if valid_horizon < horizon: + horizon_valid = torch.zeros(bsz, horizon, dtype=torch.bool, device=action.device) + horizon_valid[:, :valid_horizon] = True + action_is_pad = comp.get(f"{ACTION}_is_pad") + if action_is_pad is None: + action_is_pad = comp.get("action_horizon_is_pad") + if action_is_pad is not None: + action_pad = torch.as_tensor(action_is_pad, dtype=torch.bool, device=action.device) + if action_pad.ndim == 1: + if bsz == 1 and action_pad.numel() == horizon: + action_pad = action_pad.unsqueeze(0) + elif horizon == 1 and action_pad.numel() == bsz: + action_pad = action_pad.view(bsz, 1) + if action_pad.ndim != 2 or action_pad.shape[0] != bsz: + raise ValueError( + "action_is_pad must have shape (B, T) matching the action batch; " + f"got {tuple(action_pad.shape)} for action {tuple(action.shape)}." + ) + pad_horizon = min(horizon, action_pad.shape[1]) + horizon_valid[:, :pad_horizon] &= ~action_pad[:, :pad_horizon] + + if valid_horizon < horizon or action_is_pad is not None: action = action.clone() action[:, valid_horizon:, :] = 0 + action = action * horizon_valid.unsqueeze(-1).to(dtype=action.dtype) action_mask = torch.zeros( bsz, horizon, self.max_action_dim, dtype=torch.float32, device=action.device ) - action_mask[:, :valid_horizon, :valid_dim] = 1.0 + action_mask[:, :, :valid_dim] = horizon_valid.unsqueeze(-1).to(dtype=action_mask.dtype) transition[TransitionKey.ACTION] = action comp["action_mask"] = action_mask diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index d37b18f90..6b89271c9 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -1074,6 +1074,22 @@ def test_groot_n1_7_pack_inputs_trains_native_relative_groups_with_absolute_grip torch.testing.assert_close(output[TransitionKey.ACTION], expected_actions) +def test_groot_policy_ignores_rtc_leftovers_for_relative_actions(): + policy = object.__new__(GrootPolicy) + policy.config = SimpleNamespace(use_relative_actions=True) + policy._warned_native_relative_rtc_prefix_disabled = False + inputs = {"state": torch.zeros(1, 1, 132)} + + output_inputs, options = policy._prepare_n1_7_rtc_inputs( + inputs, + inference_delay=1, + prev_chunk_left_over=torch.ones(8, 6), + ) + + assert output_inputs is inputs + assert options is None + + def test_groot_n1_7_pack_inputs_adds_inference_action_horizon_mask(): step = GrootN17PackInputsStep( action_horizon=40, @@ -1098,6 +1114,49 @@ def test_groot_n1_7_pack_inputs_adds_inference_action_horizon_mask(): assert output[TransitionKey.COMPLEMENTARY_DATA]["embodiment_id"].dtype == torch.int32 +def test_groot_n1_7_pack_inputs_masks_padded_action_horizons(): + step = GrootN17PackInputsStep( + action_horizon=4, + valid_action_horizon=4, + max_state_dim=3, + max_action_dim=5, + normalize_min_max=False, + ) + action = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3) + action_is_pad = torch.tensor( + [ + [False, True, False, True], + [True, False, False, False], + ] + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.zeros(2, 3), + }, + TransitionKey.ACTION: action.clone(), + TransitionKey.COMPLEMENTARY_DATA: { + "task": ["Move", "Place"], + "action_is_pad": action_is_pad, + }, + } + + output = step(transition) + + expected_valid = (~action_is_pad).float() + action_mask = output[TransitionKey.COMPLEMENTARY_DATA]["action_mask"] + assert action_mask.shape == (2, 4, 5) + torch.testing.assert_close(action_mask[..., :3], expected_valid.unsqueeze(-1).expand(-1, -1, 3)) + assert action_mask[..., 3:].sum().item() == 0 + + packed_action = output[TransitionKey.ACTION] + assert packed_action.shape == (2, 4, 5) + torch.testing.assert_close(packed_action[0, 0, :3], action[0, 0]) + torch.testing.assert_close(packed_action[0, 2, :3], action[0, 2]) + assert packed_action[0, 1].abs().sum().item() == 0 + assert packed_action[0, 3].abs().sum().item() == 0 + assert packed_action[1, 0].abs().sum().item() == 0 + + def test_groot_n1_7_pack_inputs_orders_video_by_checkpoint_modality_keys(): step = GrootN17PackInputsStep( normalize_min_max=False, @@ -1904,6 +1963,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat [-2.0, -3.0, -4.0, -5.0, -6.0], [1.0, 2.0, 3.0, 4.0, 5.0], ] + assert pack_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2] assert pack_config["raw_stats"]["action"]["gripper"]["min"] == [0.0] assert pack_config["raw_stats"]["action"]["gripper"]["max"] == [100.0] @@ -1926,6 +1986,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat [-1.0, -2.0, -3.0, -4.0, -5.0], [2.0, 3.0, 4.0, 5.0, 6.0], ] + assert decode_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2] assert decode_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]