mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 17:17:01 +00:00
Format GR00T OSS parity changes
This commit is contained in:
@@ -665,7 +665,9 @@ def _relative_action_chunks_by_horizon(
|
||||
return chunks
|
||||
|
||||
|
||||
def _compute_horizon_relative_action_stats(chunks_by_horizon: list[list[np.ndarray]]) -> dict[str, np.ndarray]:
|
||||
def _compute_horizon_relative_action_stats(
|
||||
chunks_by_horizon: list[list[np.ndarray]],
|
||||
) -> dict[str, np.ndarray]:
|
||||
if not chunks_by_horizon or not any(chunks_by_horizon):
|
||||
raise ValueError("Cannot compute relative action statistics without unpadded action vectors.")
|
||||
|
||||
@@ -779,7 +781,9 @@ def _make_relative_action_training_stats(
|
||||
num_vectors += len(vectors)
|
||||
|
||||
if num_vectors < 2:
|
||||
raise ValueError("Cannot compute relative action statistics from fewer than 2 unpadded action vectors.")
|
||||
raise ValueError(
|
||||
"Cannot compute relative action statistics from fewer than 2 unpadded action vectors."
|
||||
)
|
||||
|
||||
stats[ACTION] = _compute_horizon_relative_action_stats(chunks_by_horizon or [])
|
||||
return stats
|
||||
@@ -1091,7 +1095,9 @@ def make_groot_pre_post_processors(
|
||||
if config.use_relative_actions and not checkpoint_has_stats:
|
||||
relative_dataset_stats = dataset_stats
|
||||
if not _stats_preserve_action_horizon(relative_dataset_stats):
|
||||
relative_dataset_stats = _make_relative_action_training_stats_from_dataset_meta(config, dataset_meta)
|
||||
relative_dataset_stats = _make_relative_action_training_stats_from_dataset_meta(
|
||||
config, dataset_meta
|
||||
)
|
||||
relative_assets = _build_n1_7_relative_action_processor_assets(
|
||||
config,
|
||||
relative_dataset_stats,
|
||||
@@ -1652,7 +1658,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
return None
|
||||
return torch.cat(normalized_groups, dim=-1)
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
|
||||
@@ -1870,7 +1870,6 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
|
||||
def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
|
||||
from lerobot.policies.groot import processor_groot
|
||||
|
||||
calls = []
|
||||
@@ -2116,9 +2115,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
|
||||
assert decode_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]
|
||||
|
||||
|
||||
def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_dataset_meta(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_dataset_meta(monkeypatch, tmp_path):
|
||||
input_features, output_features = _groot_features(state_dim=6, action_dim=6)
|
||||
action_names = [
|
||||
"shoulder_pan.pos",
|
||||
|
||||
Reference in New Issue
Block a user