Format GR00T OSS parity changes

This commit is contained in:
Andy Wrenn
2026-06-28 12:48:55 -07:00
parent bdc05c89e3
commit 4a3f46d0ec
2 changed files with 10 additions and 8 deletions
@@ -665,7 +665,9 @@ def _relative_action_chunks_by_horizon(
return chunks
def _compute_horizon_relative_action_stats(chunks_by_horizon: list[list[np.ndarray]]) -> dict[str, np.ndarray]:
def _compute_horizon_relative_action_stats(
chunks_by_horizon: list[list[np.ndarray]],
) -> dict[str, np.ndarray]:
if not chunks_by_horizon or not any(chunks_by_horizon):
raise ValueError("Cannot compute relative action statistics without unpadded action vectors.")
@@ -779,7 +781,9 @@ def _make_relative_action_training_stats(
num_vectors += len(vectors)
if num_vectors < 2:
raise ValueError("Cannot compute relative action statistics from fewer than 2 unpadded action vectors.")
raise ValueError(
"Cannot compute relative action statistics from fewer than 2 unpadded action vectors."
)
stats[ACTION] = _compute_horizon_relative_action_stats(chunks_by_horizon or [])
return stats
@@ -1091,7 +1095,9 @@ def make_groot_pre_post_processors(
if config.use_relative_actions and not checkpoint_has_stats:
relative_dataset_stats = dataset_stats
if not _stats_preserve_action_horizon(relative_dataset_stats):
relative_dataset_stats = _make_relative_action_training_stats_from_dataset_meta(config, dataset_meta)
relative_dataset_stats = _make_relative_action_training_stats_from_dataset_meta(
config, dataset_meta
)
relative_assets = _build_n1_7_relative_action_processor_assets(
config,
relative_dataset_stats,
@@ -1652,7 +1658,6 @@ class GrootN17PackInputsStep(ProcessorStep):
return None
return torch.cat(normalized_groups, dim=-1)
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
+1 -4
View File
@@ -1870,7 +1870,6 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
pytest.importorskip("transformers")
from lerobot.policies.groot import processor_groot
calls = []
@@ -2116,9 +2115,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
assert decode_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]
def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_dataset_meta(
monkeypatch, tmp_path
):
def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_dataset_meta(monkeypatch, tmp_path):
input_features, output_features = _groot_features(state_dim=6, action_dim=6)
action_names = [
"shoulder_pan.pos",