Fix GROOT relative action padding and RTC leftovers

This commit is contained in:
Andy Wrenn
2026-06-21 08:13:39 -07:00
parent e3f3ddd92a
commit ca60066cf6
3 changed files with 99 additions and 2 deletions
@@ -73,6 +73,7 @@ class GrootPolicy(PreTrainedPolicy):
# Initialize GR00T model using ported components
self._groot_model = self._create_groot_model()
self._action_queue_steps = self._resolve_action_queue_steps()
self._warned_native_relative_rtc_prefix_disabled = False
self.reset()
@@ -306,6 +307,17 @@ class GrootPolicy(PreTrainedPolicy):
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
if prev_chunk_left_over is None:
return inputs, None
if getattr(self.config, "use_relative_actions", False):
# Generic RTC only provides normalized leftovers from the previous chunk. For
# native relative-action N1.7 checkpoints those rows are tied to the old
# observation state and old per-horizon stats row, so using them as the next
# prefix can push the policy in the wrong direction. Run without native RTC
# overlap guidance until a GROOT-specific RTC path can pass re-anchored
# absolute leftovers through.
if not getattr(self, "_warned_native_relative_rtc_prefix_disabled", False):
logger.info("Disabling native GR00T RTC prefix for relative-action policy")
self._warned_native_relative_rtc_prefix_disabled = True
return inputs, None
if not isinstance(prev_chunk_left_over, torch.Tensor):
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
if prev_chunk_left_over.numel() == 0:
+26 -2
View File
@@ -589,6 +589,9 @@ def _slice_stats_entry(stats: dict[str, Any], indices: list[int]) -> dict[str, A
max_index = max(indices)
sliced: dict[str, Any] = {}
for stat_name, value in stats.items():
if stat_name == "count":
sliced[stat_name] = torch.as_tensor(value).flatten().tolist()
continue
tensor = torch.as_tensor(value, dtype=torch.float32)
if tensor.ndim >= 2:
if tensor.shape[-1] <= max_index:
@@ -1531,13 +1534,34 @@ class GrootN17PackInputsStep(ProcessorStep):
)
action = torch.cat([action, pad], dim=1)
horizon = self.action_horizon
if valid_horizon < horizon:
horizon_valid = torch.zeros(bsz, horizon, dtype=torch.bool, device=action.device)
horizon_valid[:, :valid_horizon] = True
action_is_pad = comp.get(f"{ACTION}_is_pad")
if action_is_pad is None:
action_is_pad = comp.get("action_horizon_is_pad")
if action_is_pad is not None:
action_pad = torch.as_tensor(action_is_pad, dtype=torch.bool, device=action.device)
if action_pad.ndim == 1:
if bsz == 1 and action_pad.numel() == horizon:
action_pad = action_pad.unsqueeze(0)
elif horizon == 1 and action_pad.numel() == bsz:
action_pad = action_pad.view(bsz, 1)
if action_pad.ndim != 2 or action_pad.shape[0] != bsz:
raise ValueError(
"action_is_pad must have shape (B, T) matching the action batch; "
f"got {tuple(action_pad.shape)} for action {tuple(action.shape)}."
)
pad_horizon = min(horizon, action_pad.shape[1])
horizon_valid[:, :pad_horizon] &= ~action_pad[:, :pad_horizon]
if valid_horizon < horizon or action_is_pad is not None:
action = action.clone()
action[:, valid_horizon:, :] = 0
action = action * horizon_valid.unsqueeze(-1).to(dtype=action.dtype)
action_mask = torch.zeros(
bsz, horizon, self.max_action_dim, dtype=torch.float32, device=action.device
)
action_mask[:, :valid_horizon, :valid_dim] = 1.0
action_mask[:, :, :valid_dim] = horizon_valid.unsqueeze(-1).to(dtype=action_mask.dtype)
transition[TransitionKey.ACTION] = action
comp["action_mask"] = action_mask
+61
View File
@@ -1074,6 +1074,22 @@ def test_groot_n1_7_pack_inputs_trains_native_relative_groups_with_absolute_grip
torch.testing.assert_close(output[TransitionKey.ACTION], expected_actions)
def test_groot_policy_ignores_rtc_leftovers_for_relative_actions():
policy = object.__new__(GrootPolicy)
policy.config = SimpleNamespace(use_relative_actions=True)
policy._warned_native_relative_rtc_prefix_disabled = False
inputs = {"state": torch.zeros(1, 1, 132)}
output_inputs, options = policy._prepare_n1_7_rtc_inputs(
inputs,
inference_delay=1,
prev_chunk_left_over=torch.ones(8, 6),
)
assert output_inputs is inputs
assert options is None
def test_groot_n1_7_pack_inputs_adds_inference_action_horizon_mask():
step = GrootN17PackInputsStep(
action_horizon=40,
@@ -1098,6 +1114,49 @@ def test_groot_n1_7_pack_inputs_adds_inference_action_horizon_mask():
assert output[TransitionKey.COMPLEMENTARY_DATA]["embodiment_id"].dtype == torch.int32
def test_groot_n1_7_pack_inputs_masks_padded_action_horizons():
step = GrootN17PackInputsStep(
action_horizon=4,
valid_action_horizon=4,
max_state_dim=3,
max_action_dim=5,
normalize_min_max=False,
)
action = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)
action_is_pad = torch.tensor(
[
[False, True, False, True],
[True, False, False, False],
]
)
transition = {
TransitionKey.OBSERVATION: {
OBS_STATE: torch.zeros(2, 3),
},
TransitionKey.ACTION: action.clone(),
TransitionKey.COMPLEMENTARY_DATA: {
"task": ["Move", "Place"],
"action_is_pad": action_is_pad,
},
}
output = step(transition)
expected_valid = (~action_is_pad).float()
action_mask = output[TransitionKey.COMPLEMENTARY_DATA]["action_mask"]
assert action_mask.shape == (2, 4, 5)
torch.testing.assert_close(action_mask[..., :3], expected_valid.unsqueeze(-1).expand(-1, -1, 3))
assert action_mask[..., 3:].sum().item() == 0
packed_action = output[TransitionKey.ACTION]
assert packed_action.shape == (2, 4, 5)
torch.testing.assert_close(packed_action[0, 0, :3], action[0, 0])
torch.testing.assert_close(packed_action[0, 2, :3], action[0, 2])
assert packed_action[0, 1].abs().sum().item() == 0
assert packed_action[0, 3].abs().sum().item() == 0
assert packed_action[1, 0].abs().sum().item() == 0
def test_groot_n1_7_pack_inputs_orders_video_by_checkpoint_modality_keys():
step = GrootN17PackInputsStep(
normalize_min_max=False,
@@ -1904,6 +1963,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
[-2.0, -3.0, -4.0, -5.0, -6.0],
[1.0, 2.0, 3.0, 4.0, 5.0],
]
assert pack_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2]
assert pack_config["raw_stats"]["action"]["gripper"]["min"] == [0.0]
assert pack_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]
@@ -1926,6 +1986,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
[-1.0, -2.0, -3.0, -4.0, -5.0],
[2.0, 3.0, 4.0, 5.0, 6.0],
]
assert decode_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2]
assert decode_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]