Address GROOT relative action review feedback

This commit is contained in:
Andy Wrenn
2026-06-20 06:30:50 -07:00
parent 57d4cd4840
commit 55da4bf8aa
5 changed files with 43 additions and 53 deletions
+2
View File
@@ -287,6 +287,7 @@ def make_pre_post_processors(
config=policy_cfg,
pretrained_path=pretrained_path,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
preprocessor_config_filename=kwargs.get(
@@ -399,6 +400,7 @@ def make_pre_post_processors(
processors = make_groot_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, XVLAConfig):
@@ -268,7 +268,6 @@ class GrootConfig(PreTrainedConfig):
)
# Groot-specific model parameters
model_version: str = GROOT_N1_7
# Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and
# checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the
@@ -325,11 +324,10 @@ class GrootConfig(PreTrainedConfig):
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
use_flash_attention: bool = False
# Train on state-relative action chunks. The listed joints stay absolute, which is normally used
# for gripper channels whose command frame is not the arm joint state.
# Enable GR00T-style state-relative action chunks. Prefer deriving action representation from
# embodiment metadata; relative_exclude_joints is a flat-vector override for datasets without it.
use_relative_actions: bool = False
relative_exclude_joints: list[str] = field(default_factory=list)
action_feature_names: list[str] | None = None
# Training parameters
optimizer_lr: float = 1e-4
@@ -365,8 +363,6 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
self.model_version = normalize_groot_model_version(self.model_version)
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
@@ -417,9 +413,9 @@ class GrootConfig(PreTrainedConfig):
setattr(self, field_name, n1_7_value)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version:
if inferred_version is not None and inferred_version != GROOT_N1_7:
message = (
f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
+16 -2
View File
@@ -435,6 +435,7 @@ def make_groot_pre_post_processors_from_pretrained(
pretrained_path: str,
*,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_meta: Any | None = None,
preprocessor_overrides: dict[str, Any] | None = None,
postprocessor_overrides: dict[str, Any] | None = None,
preprocessor_config_filename: str = f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json",
@@ -457,6 +458,7 @@ def make_groot_pre_post_processors_from_pretrained(
preprocessor, postprocessor = make_groot_pre_post_processors(
config=processor_cfg,
dataset_stats=dataset_stats,
dataset_meta=dataset_meta,
)
# Raw checkpoints have no serialized pipelines to load overrides into,
# so apply the caller overrides (e.g. device and rename_map from
@@ -546,8 +548,20 @@ def _reconnect_groot_n1_7_pack_decode_steps(
step.pack_step = pack_step
def _resolve_action_feature_names_from_dataset_meta(dataset_meta: Any | None) -> list[str] | None:
    """Look up the ``names`` entry of the action feature on ``dataset_meta``.

    ``dataset_meta.features`` is read with ``getattr``; a missing or falsy value is
    treated as an empty dict. The ``ACTION`` entry is only looked up when the
    features container is a ``dict``. The entry's ``names`` is then read either as a
    dict key or as an attribute, depending on the entry's type.

    Args:
        dataset_meta: Object expected to expose a ``features`` attribute, or ``None``.

    Returns:
        A new list built from the action feature's ``names``, or ``None`` when no
        names could be found.
    """
    feature_table = getattr(dataset_meta, "features", {}) or {}
    if not isinstance(feature_table, dict):
        action_spec = None
    else:
        action_spec = feature_table.get(ACTION)

    # The action entry may be a plain dict or an object exposing ``names``.
    if isinstance(action_spec, dict):
        raw_names = action_spec.get("names")
    else:
        raw_names = getattr(action_spec, "names", None)

    if raw_names is None:
        return None
    return list(raw_names)
def make_groot_pre_post_processors(
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
config: GrootConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_meta: Any | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
@@ -660,7 +674,7 @@ def make_groot_pre_post_processors(
relative_step = RelativeActionsProcessorStep(
enabled=True,
exclude_joints=list(config.relative_exclude_joints or []),
action_names=config.action_feature_names,
action_names=_resolve_action_feature_names_from_dataset_meta(dataset_meta),
)
input_steps.insert(2, relative_step)
+4 -8
View File
@@ -246,11 +246,7 @@ def _iter_action_state_training_samples(dataset: Any):
yield item.get(ACTION), item.get(OBS_STATE), item.get(f"{ACTION}_is_pad")
def _resolve_action_feature_names(active_cfg: Any, dataset: Any) -> list[str] | None:
config_names = getattr(active_cfg, "action_feature_names", None)
if config_names is not None:
return list(config_names)
def _resolve_action_feature_names(dataset: Any) -> list[str] | None:
features = getattr(getattr(dataset, "meta", None), "features", {}) or {}
action_feature = features.get(ACTION) if isinstance(features, dict) else None
if isinstance(action_feature, dict):
@@ -464,14 +460,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
processor_stats = _make_relative_action_training_stats(
dataset,
exclude_joints=getattr(active_cfg, "relative_exclude_joints", []),
action_names=_resolve_action_feature_names(active_cfg, dataset),
action_names=_resolve_action_feature_names(dataset),
)
processor_kwargs = {}
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
processor_kwargs["dataset_stats"] = processor_stats
if cfg.is_reward_model_training:
if cfg.is_reward_model_training or getattr(active_cfg, "use_relative_actions", False):
processor_kwargs["dataset_meta"] = dataset.meta
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
@@ -495,7 +491,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
preprocessor_overrides["relative_actions_processor"] = {
"enabled": True,
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
"action_names": _resolve_action_feature_names(active_cfg, dataset),
"action_names": _resolve_action_feature_names(dataset),
}
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
+17 -35
View File
@@ -67,11 +67,10 @@ def _groot_features(
)
def _groot_config(model_version: str = GROOT_N1_7) -> GrootConfig:
def _groot_config() -> GrootConfig:
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
kwargs = {"action_decode_transform": GROOT_ACTION_DECODE_TRANSFORM_LIBERO}
return GrootConfig(
model_version=model_version,
input_features=input_features,
output_features=output_features,
device="cpu",
@@ -83,7 +82,6 @@ def _groot_config(model_version: str = GROOT_N1_7) -> GrootConfig:
def _raw_n1_7_libero_config(model_path) -> GrootConfig:
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
return GrootConfig(
model_version=GROOT_N1_7,
base_model_path=str(model_path),
embodiment_tag="libero_sim",
input_features=input_features,
@@ -351,7 +349,6 @@ class _DummyGrootModel(nn.Module):
def test_groot_defaults_use_n1_7():
config = GrootConfig(device="cpu")
assert config.model_version == GROOT_N1_7
assert config.base_model_path == GROOT_N1_7_BASE_MODEL
assert config.max_state_dim == 132
assert config.max_action_dim == 132
@@ -362,7 +359,6 @@ def test_groot_defaults_use_n1_7():
def test_groot_n1_7_accepts_named_action_decode_transform():
config = GrootConfig(
model_version=GROOT_N1_7,
action_decode_transform="libero",
device="cpu",
)
@@ -374,23 +370,15 @@ def test_groot_n1_7_accepts_named_action_decode_transform():
def test_groot_n1_7_rejects_legacy_libero_gripper_action_decode_transform(legacy_transform):
with pytest.raises(ValueError, match="Unsupported GR00T N1.7 action decode transform"):
GrootConfig(
model_version=GROOT_N1_7,
action_decode_transform=legacy_transform,
action_decode_transform=legacy_transform,
device="cpu",
)
@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n15", "1.5"])
def test_groot_rejects_n1_5_aliases(legacy_version):
with pytest.raises(ValueError, match="Unsupported GR00T model_version"):
GrootConfig(model_version=legacy_version, device="cpu")
def test_groot_config_rejects_mismatched_n1_5_path_for_n1_7():
with pytest.raises(ValueError, match="does not match base_model_path"):
GrootConfig(
model_version=GROOT_N1_7,
base_model_path="nvidia/GR00T-N1.5-3B",
base_model_path="nvidia/GR00T-N1.5-3B",
device="cpu",
)
@@ -398,10 +386,9 @@ def test_groot_config_rejects_mismatched_n1_5_path_for_n1_7():
def test_groot_n1_7_can_be_selected_from_policy_config_factory_without_external_gr00t():
sys.modules.pop("gr00t", None)
config = make_policy_config("groot", model_version=GROOT_N1_7, device="cpu")
config = make_policy_config("groot", device="cpu")
assert isinstance(config, GrootConfig)
assert config.model_version == GROOT_N1_7
assert "gr00t" not in sys.modules
@@ -417,7 +404,7 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
dummy_model = _DummyGrootModel()
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
config = _groot_config(GROOT_N1_7)
config = _groot_config()
policy = GrootPolicy(config)
policy.config.rtc_config = SimpleNamespace(execution_horizon=6)
@@ -446,7 +433,7 @@ def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch):
dummy_model = _DummyGrootModel()
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
config = _groot_config(GROOT_N1_7)
config = _groot_config()
policy = GrootPolicy(config)
policy.config.rtc_config = SimpleNamespace(execution_horizon=6)
@@ -490,7 +477,6 @@ def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(t
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel()))
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path=str(model_path),
embodiment_tag="libero_sim",
input_features=input_features,
@@ -518,8 +504,7 @@ def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
# so construction itself raises before from_pretrained is reached.
with pytest.raises(ValueError, match="does not match base_model_path"):
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path="nvidia/GR00T-N1.5-3B",
base_model_path="nvidia/GR00T-N1.5-3B",
input_features=input_features,
output_features=output_features,
device="cpu",
@@ -534,13 +519,12 @@ def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatc
model_path = tmp_path / "GR00T-N1.7-local"
model_path.mkdir()
config = _groot_config(GROOT_N1_7)
config = _groot_config()
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: _DummyGrootModel()))
policy = GrootPolicy.from_pretrained(model_path, config=config)
assert policy.config.model_version == GROOT_N1_7
assert policy.config.base_model_path == str(model_path)
@@ -555,7 +539,6 @@ def test_groot_from_pretrained_infers_n1_7_from_ambiguous_local_config(tmp_path,
policy = GrootPolicy.from_pretrained(model_path)
assert policy.config.model_version == GROOT_N1_7
assert policy.config.base_model_path == str(model_path)
@@ -1340,8 +1323,7 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config(
# so construction itself raises before from_pretrained is reached.
with pytest.raises(ValueError, match="does not match base_model_path"):
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path="nvidia/GR00T-N1.5-3B",
base_model_path="nvidia/GR00T-N1.5-3B",
input_features=input_features,
output_features=output_features,
device="cpu",
@@ -1353,7 +1335,7 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config(
def test_groot_n1_7_processors_are_registered_lazily_without_external_gr00t():
sys.modules.pop("gr00t", None)
config = _groot_config(GROOT_N1_7)
config = _groot_config()
preprocessor, _ = make_groot_pre_post_processors(config)
step_types = {type(step) for step in preprocessor.steps}
@@ -1692,7 +1674,7 @@ def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
def test_groot_n1_7_saved_processors_reload_through_factory(tmp_path):
config = _groot_config(GROOT_N1_7)
config = _groot_config()
dataset_stats = {
OBS_STATE: {
"min": torch.zeros(8),
@@ -1724,7 +1706,7 @@ def test_groot_n1_7_saved_processors_reload_through_factory(tmp_path):
def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stats(tmp_path):
config = _groot_config(GROOT_N1_7)
config = _groot_config()
saved_stats = {
OBS_STATE: {
"min": torch.full((8,), -2.0),
@@ -1774,7 +1756,6 @@ def test_groot_n1_7_relative_action_training_processors_save_relative_action_sta
action_decode_transform=None,
use_relative_actions=True,
relative_exclude_joints=["gripper"],
action_feature_names=action_names,
)
absolute_dataset_stats = {
OBS_STATE: {
@@ -1829,7 +1810,9 @@ def test_groot_n1_7_relative_action_training_processors_save_relative_action_sta
"max": torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0, 100.0]),
}
preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=relative_dataset_stats)
preprocessor, postprocessor = make_groot_pre_post_processors(
config, dataset_stats=relative_dataset_stats, dataset_meta=_RelativeStatsDataset.meta
)
preprocessor.save_pretrained(tmp_path)
postprocessor.save_pretrained(tmp_path)
@@ -1867,7 +1850,7 @@ def test_groot_policy_selects_n1_7_model_class(monkeypatch):
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(fake_from_pretrained))
policy = GrootPolicy(_groot_config(GROOT_N1_7))
policy = GrootPolicy(_groot_config())
assert called["pretrained_model_name_or_path"] == GROOT_N1_7_BASE_MODEL
assert isinstance(policy._groot_model, _DummyGrootModel)
@@ -1878,7 +1861,7 @@ def test_groot_policy_forwards_n1_7_qwen_inputs(monkeypatch):
dummy_model = _DummyGrootModel()
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
policy = GrootPolicy(_groot_config(GROOT_N1_7))
policy = GrootPolicy(_groot_config())
batch = {
"state": torch.zeros(2, 1, 132),
@@ -1941,7 +1924,6 @@ def test_groot_n1_7_select_action_uses_checkpoint_valid_horizon(tmp_path, monkey
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel()))
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path=str(model_path),
embodiment_tag="libero_sim",
input_features=input_features,