mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Fix GR00T N1.7 RTC action decoding
This commit is contained in:
@@ -319,6 +319,85 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
def _prepare_n1_7_rtc_inputs(
|
||||
self,
|
||||
inputs: dict[str, Tensor],
|
||||
*,
|
||||
inference_delay: object,
|
||||
prev_chunk_left_over: object,
|
||||
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
|
||||
if self.config.model_version != GROOT_N1_7 or prev_chunk_left_over is None:
|
||||
return inputs, None
|
||||
if not isinstance(prev_chunk_left_over, torch.Tensor):
|
||||
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
|
||||
if prev_chunk_left_over.numel() == 0:
|
||||
return inputs, None
|
||||
|
||||
prev_actions = prev_chunk_left_over
|
||||
if prev_actions.ndim == 2:
|
||||
prev_actions = prev_actions.unsqueeze(0)
|
||||
elif prev_actions.ndim != 3:
|
||||
raise ValueError(
|
||||
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
|
||||
)
|
||||
|
||||
state = inputs.get("state")
|
||||
if state is None:
|
||||
raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.")
|
||||
batch_size = state.shape[0]
|
||||
if prev_actions.shape[0] == 1 and batch_size > 1:
|
||||
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
||||
elif prev_actions.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
|
||||
)
|
||||
|
||||
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
|
||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||
if prev_actions.shape[1] > model_action_horizon:
|
||||
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
||||
|
||||
action_horizon = int(prev_actions.shape[1])
|
||||
if action_horizon <= 0:
|
||||
return inputs, None
|
||||
|
||||
if prev_actions.shape[2] > max_action_dim:
|
||||
prev_actions = prev_actions[:, :, :max_action_dim]
|
||||
elif prev_actions.shape[2] < max_action_dim:
|
||||
pad = torch.zeros(
|
||||
prev_actions.shape[0],
|
||||
prev_actions.shape[1],
|
||||
max_action_dim - prev_actions.shape[2],
|
||||
dtype=prev_actions.dtype,
|
||||
device=prev_actions.device,
|
||||
)
|
||||
prev_actions = torch.cat([prev_actions, pad], dim=2)
|
||||
|
||||
prev_actions = prev_actions.to(device=state.device, dtype=state.dtype)
|
||||
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon))
|
||||
overlap_steps = max(0, min(action_horizon, execution_horizon))
|
||||
if overlap_steps == 0:
|
||||
return inputs, None
|
||||
|
||||
try:
|
||||
frozen_steps = int(inference_delay or 0)
|
||||
except (TypeError, ValueError):
|
||||
frozen_steps = 0
|
||||
frozen_steps = max(0, min(frozen_steps, overlap_steps))
|
||||
|
||||
options = {
|
||||
"action_horizon": action_horizon,
|
||||
"rtc_overlap_steps": overlap_steps,
|
||||
"rtc_frozen_steps": frozen_steps,
|
||||
"rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)),
|
||||
}
|
||||
|
||||
inputs = dict(inputs)
|
||||
inputs["action"] = prev_actions
|
||||
return inputs, options
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Training forward pass.
|
||||
|
||||
@@ -342,14 +421,13 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **_: object) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor:
|
||||
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
|
||||
|
||||
Returns a tensor of shape (B, n_action_steps, action_dim).
|
||||
|
||||
Groot does not currently implement LeRobot's RTC guidance contract. Accept
|
||||
and ignore action-selection kwargs so the RTC engine can still use Groot as
|
||||
an async chunk producer.
|
||||
For N1.7, LeRobot's RTC leftovers are converted into the native GR00T
|
||||
action-overlap options before calling the underlying model.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
@@ -357,13 +435,23 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
# During inference, we do not pass action because it is predicted.
|
||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
|
||||
groot_options = None
|
||||
if self.config.model_version == GROOT_N1_7:
|
||||
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
|
||||
groot_inputs,
|
||||
inference_delay=kwargs.get("inference_delay"),
|
||||
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
|
||||
)
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Use bf16 autocast for inference to keep memory low and match backbone dtype
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
if groot_options is not None:
|
||||
outputs = self._groot_model.get_action(groot_inputs, options=groot_options)
|
||||
else:
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
|
||||
@@ -494,6 +494,17 @@ def _legacy_groot_processor_overrides(
|
||||
return preprocessor_overrides, postprocessor_overrides
|
||||
|
||||
|
||||
def _local_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
|
||||
path = Path(pretrained_path).expanduser()
|
||||
if not path.is_dir():
|
||||
return False
|
||||
config = _read_json(path / config_filename)
|
||||
steps = config.get("steps", [])
|
||||
if not isinstance(steps, list):
|
||||
return False
|
||||
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
|
||||
|
||||
|
||||
def make_groot_pre_post_processors_from_pretrained(
|
||||
config: GrootConfig,
|
||||
pretrained_path: str,
|
||||
@@ -517,12 +528,26 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
dataset_stats=dataset_stats,
|
||||
)
|
||||
|
||||
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
|
||||
config=config,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
)
|
||||
if (
|
||||
config.model_version == GROOT_N1_7
|
||||
and _local_processor_config_has_step(
|
||||
pretrained_path,
|
||||
postprocessor_config_filename,
|
||||
"groot_n1_7_action_decode_v1",
|
||||
)
|
||||
):
|
||||
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
|
||||
# action decoder. Adding the legacy action-unpack override would target
|
||||
# a step that is not present and break loading.
|
||||
preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||
postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||
else:
|
||||
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
|
||||
config=config,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
)
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=preprocessor_config_filename,
|
||||
@@ -1712,6 +1737,15 @@ def _unnormalize_min_max(action: np.ndarray, min_v: np.ndarray, max_v: np.ndarra
|
||||
return (np.clip(action, -1.0, 1.0) + 1.0) * 0.5 * (max_v - min_v) + min_v
|
||||
|
||||
|
||||
def _n1_7_decode_valid_horizon(action_config: dict[str, Any], action_np: np.ndarray) -> int | None:
|
||||
if action_np.ndim != 3:
|
||||
return None
|
||||
delta_indices = action_config.get("delta_indices", [])
|
||||
if not isinstance(delta_indices, list) or not delta_indices:
|
||||
return None
|
||||
return max(1, min(action_np.shape[1], len(delta_indices)))
|
||||
|
||||
|
||||
def _rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
|
||||
rows = rot6d.reshape(2, 3).astype(np.float64)
|
||||
row1 = rows[0] / np.linalg.norm(rows[0])
|
||||
@@ -1824,6 +1858,9 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
return transition
|
||||
|
||||
action_np = action.detach().cpu().float().numpy()
|
||||
valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np)
|
||||
if valid_horizon is not None:
|
||||
action_np = action_np[:, :valid_horizon]
|
||||
decoded_groups: dict[str, np.ndarray] = {}
|
||||
start_idx = 0
|
||||
for idx, key in enumerate(action_keys):
|
||||
|
||||
@@ -334,13 +334,15 @@ class _DummyGrootModel(nn.Module):
|
||||
self.config = SimpleNamespace(compute_dtype="float32")
|
||||
self.compute_dtype = "float32"
|
||||
self.forward_inputs = None
|
||||
self.get_action_options = None
|
||||
|
||||
def forward(self, inputs):
|
||||
self.forward_inputs = dict(inputs)
|
||||
return {"loss": self.weight + 1.0}
|
||||
|
||||
def get_action(self, inputs):
|
||||
def get_action(self, inputs, options=None):
|
||||
self.forward_inputs = dict(inputs)
|
||||
self.get_action_options = options
|
||||
batch_size = inputs["state"].shape[0]
|
||||
return {"action_pred": torch.zeros(batch_size, 40, 132, device=self.weight.device)}
|
||||
|
||||
@@ -427,6 +429,35 @@ def test_groot_predict_action_chunk_accepts_rtc_kwargs():
|
||||
signature.bind(object(), {}, inference_delay=2, prev_chunk_left_over=None)
|
||||
|
||||
|
||||
def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
|
||||
config = _groot_config(GROOT_N1_7)
|
||||
policy = GrootPolicy(config)
|
||||
policy.config.rtc_config = SimpleNamespace(execution_horizon=6)
|
||||
|
||||
prev_chunk = torch.arange(8 * 7, dtype=torch.float32).view(8, 7)
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
{"state": torch.zeros(1, 1, 132)},
|
||||
inference_delay=3,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
)
|
||||
|
||||
assert actions.shape == (1, 40, 7)
|
||||
assert dummy_model.get_action_options == {
|
||||
"action_horizon": 8,
|
||||
"rtc_overlap_steps": 6,
|
||||
"rtc_frozen_steps": 3,
|
||||
"rtc_ramp_rate": 6.0,
|
||||
}
|
||||
assert dummy_model.forward_inputs["action"].shape == (1, 8, 132)
|
||||
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, :7], prev_chunk)
|
||||
torch.testing.assert_close(dummy_model.forward_inputs["action"][0, :, 7:], torch.zeros(8, 125))
|
||||
|
||||
|
||||
def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
|
||||
model_path = tmp_path / "GR00T-N1.7-local"
|
||||
model_path.mkdir()
|
||||
@@ -593,6 +624,27 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
|
||||
assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0]
|
||||
|
||||
|
||||
def test_converted_raw_n1_7_processors_load_without_legacy_action_unpack_override(tmp_path):
|
||||
model_path = tmp_path / "libero_spatial"
|
||||
_write_raw_n1_7_libero_checkpoint(model_path)
|
||||
config = _raw_n1_7_libero_config(model_path)
|
||||
preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path))
|
||||
save_dir = tmp_path / "converted_pretrained_model"
|
||||
|
||||
config.save_pretrained(save_dir)
|
||||
preprocessor.save_pretrained(save_dir)
|
||||
postprocessor.save_pretrained(save_dir)
|
||||
|
||||
loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
|
||||
config,
|
||||
pretrained_path=str(save_dir),
|
||||
preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}},
|
||||
)
|
||||
|
||||
assert any(isinstance(step, GrootN17PackInputsStep) for step in loaded_preprocessor.steps)
|
||||
assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps)
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_inputs_rejects_state_dim_above_core_max():
|
||||
step = GrootN17PackInputsStep(
|
||||
max_state_dim=2,
|
||||
@@ -941,6 +993,76 @@ def test_groot_n1_7_action_decode_applies_named_libero_transform_from_modality_k
|
||||
torch.testing.assert_close(output[TransitionKey.ACTION], expected)
|
||||
|
||||
|
||||
def test_groot_n1_7_action_decode_truncates_to_valid_horizon_for_relative_stats():
|
||||
arm_min = [[-1.0] * 5 for _ in range(16)]
|
||||
arm_max = [[1.0] * 5 for _ in range(16)]
|
||||
raw_stats = {
|
||||
"state": {
|
||||
"single_arm": _stats([0.0] * 5),
|
||||
"gripper": _stats([0.0]),
|
||||
},
|
||||
"action": {
|
||||
"single_arm": _stats([0.0] * 5),
|
||||
"gripper": {
|
||||
"min": [0.0],
|
||||
"max": [10.0],
|
||||
"mean": [5.0],
|
||||
"std": [1.0],
|
||||
"q01": [0.0],
|
||||
"q99": [10.0],
|
||||
},
|
||||
},
|
||||
"relative_action": {
|
||||
"single_arm": {
|
||||
"min": arm_min,
|
||||
"max": arm_max,
|
||||
"mean": [[0.0] * 5 for _ in range(16)],
|
||||
"std": [[1.0] * 5 for _ in range(16)],
|
||||
"q01": arm_min,
|
||||
"q99": arm_max,
|
||||
},
|
||||
},
|
||||
}
|
||||
modality_config = {
|
||||
"state": {
|
||||
"modality_keys": ["single_arm", "gripper"],
|
||||
},
|
||||
"action": {
|
||||
"delta_indices": list(range(16)),
|
||||
"modality_keys": ["single_arm", "gripper"],
|
||||
"action_configs": [
|
||||
{"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
|
||||
{"rep": "ABSOLUTE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
|
||||
],
|
||||
},
|
||||
}
|
||||
pack_step = GrootN17PackInputsStep(
|
||||
raw_stats=raw_stats,
|
||||
modality_config=modality_config,
|
||||
normalize_min_max=False,
|
||||
)
|
||||
pack_step(
|
||||
{
|
||||
TransitionKey.OBSERVATION: {OBS_STATE: torch.zeros(1, 6)},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
)
|
||||
decode_step = GrootN17ActionDecodeStep(
|
||||
env_action_dim=6,
|
||||
raw_stats=raw_stats,
|
||||
modality_config=modality_config,
|
||||
use_relative_action=True,
|
||||
pack_step=pack_step,
|
||||
)
|
||||
|
||||
output = decode_step({TransitionKey.ACTION: torch.zeros(1, 40, 6)})
|
||||
|
||||
decoded = output[TransitionKey.ACTION]
|
||||
assert decoded.shape == (1, 16, 6)
|
||||
torch.testing.assert_close(decoded[..., :5], torch.zeros(1, 16, 5))
|
||||
torch.testing.assert_close(decoded[..., 5], torch.full((1, 16), 5.0))
|
||||
|
||||
|
||||
def test_groot_n1_7_action_decode_requires_gripper_key_for_libero_transform():
|
||||
step = GrootN17ActionDecodeStep(
|
||||
env_action_dim=1,
|
||||
|
||||
Reference in New Issue
Block a user