mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 08:37:10 +00:00
Fix GROOT relative action padding and RTC leftovers
This commit is contained in:
@@ -73,6 +73,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
# Initialize GR00T model using ported components
|
||||
self._groot_model = self._create_groot_model()
|
||||
self._action_queue_steps = self._resolve_action_queue_steps()
|
||||
self._warned_native_relative_rtc_prefix_disabled = False
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -306,6 +307,17 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
|
||||
if prev_chunk_left_over is None:
|
||||
return inputs, None
|
||||
if getattr(self.config, "use_relative_actions", False):
|
||||
# Generic RTC only provides normalized leftovers from the previous chunk. For
|
||||
# native relative-action N1.7 checkpoints those rows are tied to the old
|
||||
# observation state and old per-horizon stats row, so using them as the next
|
||||
# prefix can push the policy in the wrong direction. Run without native RTC
|
||||
# overlap guidance until a GROOT-specific RTC path can pass re-anchored
|
||||
# absolute leftovers through.
|
||||
if not getattr(self, "_warned_native_relative_rtc_prefix_disabled", False):
|
||||
logger.info("Disabling native GR00T RTC prefix for relative-action policy")
|
||||
self._warned_native_relative_rtc_prefix_disabled = True
|
||||
return inputs, None
|
||||
if not isinstance(prev_chunk_left_over, torch.Tensor):
|
||||
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
|
||||
if prev_chunk_left_over.numel() == 0:
|
||||
|
||||
@@ -588,6 +588,9 @@ def _slice_stats_entry(stats: dict[str, Any], indices: list[int]) -> dict[str, A
|
||||
max_index = max(indices)
|
||||
sliced: dict[str, Any] = {}
|
||||
for stat_name, value in stats.items():
|
||||
if stat_name == "count":
|
||||
sliced[stat_name] = torch.as_tensor(value).flatten().tolist()
|
||||
continue
|
||||
tensor = torch.as_tensor(value, dtype=torch.float32)
|
||||
if tensor.ndim >= 2:
|
||||
if tensor.shape[-1] <= max_index:
|
||||
@@ -1537,13 +1540,34 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
)
|
||||
action = torch.cat([action, pad], dim=1)
|
||||
horizon = self.action_horizon
|
||||
if valid_horizon < horizon:
|
||||
horizon_valid = torch.zeros(bsz, horizon, dtype=torch.bool, device=action.device)
|
||||
horizon_valid[:, :valid_horizon] = True
|
||||
action_is_pad = comp.get(f"{ACTION}_is_pad")
|
||||
if action_is_pad is None:
|
||||
action_is_pad = comp.get("action_horizon_is_pad")
|
||||
if action_is_pad is not None:
|
||||
action_pad = torch.as_tensor(action_is_pad, dtype=torch.bool, device=action.device)
|
||||
if action_pad.ndim == 1:
|
||||
if bsz == 1 and action_pad.numel() == horizon:
|
||||
action_pad = action_pad.unsqueeze(0)
|
||||
elif horizon == 1 and action_pad.numel() == bsz:
|
||||
action_pad = action_pad.view(bsz, 1)
|
||||
if action_pad.ndim != 2 or action_pad.shape[0] != bsz:
|
||||
raise ValueError(
|
||||
"action_is_pad must have shape (B, T) matching the action batch; "
|
||||
f"got {tuple(action_pad.shape)} for action {tuple(action.shape)}."
|
||||
)
|
||||
pad_horizon = min(horizon, action_pad.shape[1])
|
||||
horizon_valid[:, :pad_horizon] &= ~action_pad[:, :pad_horizon]
|
||||
|
||||
if valid_horizon < horizon or action_is_pad is not None:
|
||||
action = action.clone()
|
||||
action[:, valid_horizon:, :] = 0
|
||||
action = action * horizon_valid.unsqueeze(-1).to(dtype=action.dtype)
|
||||
action_mask = torch.zeros(
|
||||
bsz, horizon, self.max_action_dim, dtype=torch.float32, device=action.device
|
||||
)
|
||||
action_mask[:, :valid_horizon, :valid_dim] = 1.0
|
||||
action_mask[:, :, :valid_dim] = horizon_valid.unsqueeze(-1).to(dtype=action_mask.dtype)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
comp["action_mask"] = action_mask
|
||||
|
||||
|
||||
@@ -1074,6 +1074,22 @@ def test_groot_n1_7_pack_inputs_trains_native_relative_groups_with_absolute_grip
|
||||
torch.testing.assert_close(output[TransitionKey.ACTION], expected_actions)
|
||||
|
||||
|
||||
def test_groot_policy_ignores_rtc_leftovers_for_relative_actions():
    """Relative-action policies must skip the native RTC prefix path.

    When ``config.use_relative_actions`` is true, ``_prepare_n1_7_rtc_inputs``
    is expected to hand back the very same ``inputs`` object with no RTC
    options, even though a non-empty ``prev_chunk_left_over`` is supplied, and
    to record that the one-time notice has been emitted.
    """
    # Bypass __init__; only the attributes read by the relative-action branch are set.
    policy = object.__new__(GrootPolicy)
    policy.config = SimpleNamespace(use_relative_actions=True)
    policy._warned_native_relative_rtc_prefix_disabled = False
    inputs = {"state": torch.zeros(1, 1, 132)}

    output_inputs, options = policy._prepare_n1_7_rtc_inputs(
        inputs,
        inference_delay=1,
        prev_chunk_left_over=torch.ones(8, 6),
    )

    # Identity (not just equality): the inputs must be passed through untouched.
    assert output_inputs is inputs
    assert options is None
    # The log-once guard must flip so later chunks do not repeat the notice.
    assert policy._warned_native_relative_rtc_prefix_disabled is True
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_inputs_adds_inference_action_horizon_mask():
|
||||
step = GrootN17PackInputsStep(
|
||||
action_horizon=40,
|
||||
@@ -1098,6 +1114,49 @@ def test_groot_n1_7_pack_inputs_adds_inference_action_horizon_mask():
|
||||
assert output[TransitionKey.COMPLEMENTARY_DATA]["embodiment_id"].dtype == torch.int32
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_inputs_masks_padded_action_horizons():
    """Per-sample ``action_is_pad`` flags must zero both the mask and the packed action.

    A batch of two 4-step, 3-dim action chunks is packed into a 5-dim action
    space. Timesteps flagged as padding are expected to have an all-zero
    ``action_mask`` row and an all-zero packed action row, while unflagged
    timesteps keep their original values in the first three dims.
    """
    pack_step = GrootN17PackInputsStep(
        action_horizon=4,
        valid_action_horizon=4,
        max_state_dim=3,
        max_action_dim=5,
        normalize_min_max=False,
    )
    action = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)
    action_is_pad = torch.tensor(
        [
            [False, True, False, True],
            [True, False, False, False],
        ]
    )
    observation = {OBS_STATE: torch.zeros(2, 3)}
    complementary = {
        "task": ["Move", "Place"],
        "action_is_pad": action_is_pad,
    }
    transition = {
        TransitionKey.OBSERVATION: observation,
        # Clone so the reference tensor used in the assertions stays untouched.
        TransitionKey.ACTION: action.clone(),
        TransitionKey.COMPLEMENTARY_DATA: complementary,
    }

    output = pack_step(transition)

    # Mask: 1.0 on real dims of unpadded timesteps, 0.0 everywhere else.
    expected_valid = action_is_pad.logical_not().to(torch.float32)
    action_mask = output[TransitionKey.COMPLEMENTARY_DATA]["action_mask"]
    assert action_mask.shape == (2, 4, 5)
    torch.testing.assert_close(action_mask[..., :3], expected_valid[..., None].expand(-1, -1, 3))
    assert action_mask[..., 3:].sum().item() == 0

    packed_action = output[TransitionKey.ACTION]
    assert packed_action.shape == (2, 4, 5)
    # Unpadded timesteps keep their original values in the real action dims.
    for batch_idx, step_idx in ((0, 0), (0, 2)):
        torch.testing.assert_close(packed_action[batch_idx, step_idx, :3], action[batch_idx, step_idx])
    # Padded timesteps are zeroed across the full padded action width.
    for batch_idx, step_idx in ((0, 1), (0, 3), (1, 0)):
        assert packed_action[batch_idx, step_idx].abs().sum().item() == 0
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_inputs_orders_video_by_checkpoint_modality_keys():
|
||||
step = GrootN17PackInputsStep(
|
||||
normalize_min_max=False,
|
||||
@@ -1904,6 +1963,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
|
||||
[-2.0, -3.0, -4.0, -5.0, -6.0],
|
||||
[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
]
|
||||
assert pack_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2]
|
||||
assert pack_config["raw_stats"]["action"]["gripper"]["min"] == [0.0]
|
||||
assert pack_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]
|
||||
|
||||
@@ -1926,6 +1986,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
|
||||
[-1.0, -2.0, -3.0, -4.0, -5.0],
|
||||
[2.0, 3.0, 4.0, 5.0, 6.0],
|
||||
]
|
||||
assert decode_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2]
|
||||
assert decode_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user