mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 08:37:10 +00:00
Address GROOT relative action review feedback
This commit is contained in:
@@ -287,6 +287,7 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
pretrained_path=pretrained_path,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
|
||||
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
|
||||
preprocessor_config_filename=kwargs.get(
|
||||
@@ -399,6 +400,7 @@ def make_pre_post_processors(
|
||||
processors = make_groot_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, XVLAConfig):
|
||||
|
||||
@@ -268,7 +268,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
)
|
||||
|
||||
# Groot-specific model parameters
|
||||
model_version: str = GROOT_N1_7
|
||||
|
||||
# Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and
|
||||
# checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the
|
||||
@@ -325,11 +324,10 @@ class GrootConfig(PreTrainedConfig):
|
||||
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
|
||||
use_flash_attention: bool = False
|
||||
|
||||
# Train on state-relative action chunks. The listed joints stay absolute, which is normally used
|
||||
# for gripper channels whose command frame is not the arm joint state.
|
||||
# Enable GR00T-style state-relative action chunks. Prefer deriving action representation from
|
||||
# embodiment metadata; relative_exclude_joints is a flat-vector override for datasets without it.
|
||||
use_relative_actions: bool = False
|
||||
relative_exclude_joints: list[str] = field(default_factory=list)
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Training parameters
|
||||
optimizer_lr: float = 1e-4
|
||||
@@ -365,8 +363,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
@@ -417,9 +413,9 @@ class GrootConfig(PreTrainedConfig):
|
||||
setattr(self, field_name, n1_7_value)
|
||||
|
||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||
if inferred_version is not None and inferred_version != self.model_version:
|
||||
if inferred_version is not None and inferred_version != GROOT_N1_7:
|
||||
message = (
|
||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
|
||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
|
||||
@@ -435,6 +435,7 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
pretrained_path: str,
|
||||
*,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
dataset_meta: Any | None = None,
|
||||
preprocessor_overrides: dict[str, Any] | None = None,
|
||||
postprocessor_overrides: dict[str, Any] | None = None,
|
||||
preprocessor_config_filename: str = f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json",
|
||||
@@ -457,6 +458,7 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(
|
||||
config=processor_cfg,
|
||||
dataset_stats=dataset_stats,
|
||||
dataset_meta=dataset_meta,
|
||||
)
|
||||
# Raw checkpoints have no serialized pipelines to load overrides into,
|
||||
# so apply the caller overrides (e.g. device and rename_map from
|
||||
@@ -546,8 +548,20 @@ def _reconnect_groot_n1_7_pack_decode_steps(
|
||||
step.pack_step = pack_step
|
||||
|
||||
|
||||
def _resolve_action_feature_names_from_dataset_meta(dataset_meta: Any | None) -> list[str] | None:
|
||||
features = getattr(dataset_meta, "features", {}) or {}
|
||||
action_feature = features.get(ACTION) if isinstance(features, dict) else None
|
||||
if isinstance(action_feature, dict):
|
||||
names = action_feature.get("names")
|
||||
else:
|
||||
names = getattr(action_feature, "names", None)
|
||||
return list(names) if names is not None else None
|
||||
|
||||
|
||||
def make_groot_pre_post_processors(
|
||||
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
config: GrootConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
dataset_meta: Any | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
@@ -660,7 +674,7 @@ def make_groot_pre_post_processors(
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=list(config.relative_exclude_joints or []),
|
||||
action_names=config.action_feature_names,
|
||||
action_names=_resolve_action_feature_names_from_dataset_meta(dataset_meta),
|
||||
)
|
||||
input_steps.insert(2, relative_step)
|
||||
|
||||
|
||||
@@ -246,11 +246,7 @@ def _iter_action_state_training_samples(dataset: Any):
|
||||
yield item.get(ACTION), item.get(OBS_STATE), item.get(f"{ACTION}_is_pad")
|
||||
|
||||
|
||||
def _resolve_action_feature_names(active_cfg: Any, dataset: Any) -> list[str] | None:
|
||||
config_names = getattr(active_cfg, "action_feature_names", None)
|
||||
if config_names is not None:
|
||||
return list(config_names)
|
||||
|
||||
def _resolve_action_feature_names(dataset: Any) -> list[str] | None:
|
||||
features = getattr(getattr(dataset, "meta", None), "features", {}) or {}
|
||||
action_feature = features.get(ACTION) if isinstance(features, dict) else None
|
||||
if isinstance(action_feature, dict):
|
||||
@@ -464,14 +460,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
processor_stats = _make_relative_action_training_stats(
|
||||
dataset,
|
||||
exclude_joints=getattr(active_cfg, "relative_exclude_joints", []),
|
||||
action_names=_resolve_action_feature_names(active_cfg, dataset),
|
||||
action_names=_resolve_action_feature_names(dataset),
|
||||
)
|
||||
|
||||
processor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
processor_kwargs["dataset_stats"] = processor_stats
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
if cfg.is_reward_model_training or getattr(active_cfg, "use_relative_actions", False):
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
@@ -495,7 +491,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
preprocessor_overrides["relative_actions_processor"] = {
|
||||
"enabled": True,
|
||||
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
|
||||
"action_names": _resolve_action_feature_names(active_cfg, dataset),
|
||||
"action_names": _resolve_action_feature_names(dataset),
|
||||
}
|
||||
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
|
||||
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
|
||||
@@ -67,11 +67,10 @@ def _groot_features(
|
||||
)
|
||||
|
||||
|
||||
def _groot_config(model_version: str = GROOT_N1_7) -> GrootConfig:
|
||||
def _groot_config() -> GrootConfig:
|
||||
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
|
||||
kwargs = {"action_decode_transform": GROOT_ACTION_DECODE_TRANSFORM_LIBERO}
|
||||
return GrootConfig(
|
||||
model_version=model_version,
|
||||
input_features=input_features,
|
||||
output_features=output_features,
|
||||
device="cpu",
|
||||
@@ -83,7 +82,6 @@ def _groot_config(model_version: str = GROOT_N1_7) -> GrootConfig:
|
||||
def _raw_n1_7_libero_config(model_path) -> GrootConfig:
|
||||
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
|
||||
return GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path=str(model_path),
|
||||
embodiment_tag="libero_sim",
|
||||
input_features=input_features,
|
||||
@@ -351,7 +349,6 @@ class _DummyGrootModel(nn.Module):
|
||||
def test_groot_defaults_use_n1_7():
|
||||
config = GrootConfig(device="cpu")
|
||||
|
||||
assert config.model_version == GROOT_N1_7
|
||||
assert config.base_model_path == GROOT_N1_7_BASE_MODEL
|
||||
assert config.max_state_dim == 132
|
||||
assert config.max_action_dim == 132
|
||||
@@ -362,7 +359,6 @@ def test_groot_defaults_use_n1_7():
|
||||
|
||||
def test_groot_n1_7_accepts_named_action_decode_transform():
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
action_decode_transform="libero",
|
||||
device="cpu",
|
||||
)
|
||||
@@ -374,23 +370,15 @@ def test_groot_n1_7_accepts_named_action_decode_transform():
|
||||
def test_groot_n1_7_rejects_legacy_libero_gripper_action_decode_transform(legacy_transform):
|
||||
with pytest.raises(ValueError, match="Unsupported GR00T N1.7 action decode transform"):
|
||||
GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
action_decode_transform=legacy_transform,
|
||||
action_decode_transform=legacy_transform,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n15", "1.5"])
|
||||
def test_groot_rejects_n1_5_aliases(legacy_version):
|
||||
with pytest.raises(ValueError, match="Unsupported GR00T model_version"):
|
||||
GrootConfig(model_version=legacy_version, device="cpu")
|
||||
|
||||
|
||||
def test_groot_config_rejects_mismatched_n1_5_path_for_n1_7():
|
||||
with pytest.raises(ValueError, match="does not match base_model_path"):
|
||||
GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
@@ -398,10 +386,9 @@ def test_groot_config_rejects_mismatched_n1_5_path_for_n1_7():
|
||||
def test_groot_n1_7_can_be_selected_from_policy_config_factory_without_external_gr00t():
|
||||
sys.modules.pop("gr00t", None)
|
||||
|
||||
config = make_policy_config("groot", model_version=GROOT_N1_7, device="cpu")
|
||||
config = make_policy_config("groot", device="cpu")
|
||||
|
||||
assert isinstance(config, GrootConfig)
|
||||
assert config.model_version == GROOT_N1_7
|
||||
assert "gr00t" not in sys.modules
|
||||
|
||||
|
||||
@@ -417,7 +404,7 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
|
||||
config = _groot_config(GROOT_N1_7)
|
||||
config = _groot_config()
|
||||
policy = GrootPolicy(config)
|
||||
policy.config.rtc_config = SimpleNamespace(execution_horizon=6)
|
||||
|
||||
@@ -446,7 +433,7 @@ def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch):
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
|
||||
config = _groot_config(GROOT_N1_7)
|
||||
config = _groot_config()
|
||||
policy = GrootPolicy(config)
|
||||
policy.config.rtc_config = SimpleNamespace(execution_horizon=6)
|
||||
|
||||
@@ -490,7 +477,6 @@ def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(t
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel()))
|
||||
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path=str(model_path),
|
||||
embodiment_tag="libero_sim",
|
||||
input_features=input_features,
|
||||
@@ -518,8 +504,7 @@ def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
|
||||
# so construction itself raises before from_pretrained is reached.
|
||||
with pytest.raises(ValueError, match="does not match base_model_path"):
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
input_features=input_features,
|
||||
output_features=output_features,
|
||||
device="cpu",
|
||||
@@ -534,13 +519,12 @@ def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatc
|
||||
|
||||
model_path = tmp_path / "GR00T-N1.7-local"
|
||||
model_path.mkdir()
|
||||
config = _groot_config(GROOT_N1_7)
|
||||
config = _groot_config()
|
||||
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: _DummyGrootModel()))
|
||||
|
||||
policy = GrootPolicy.from_pretrained(model_path, config=config)
|
||||
|
||||
assert policy.config.model_version == GROOT_N1_7
|
||||
assert policy.config.base_model_path == str(model_path)
|
||||
|
||||
|
||||
@@ -555,7 +539,6 @@ def test_groot_from_pretrained_infers_n1_7_from_ambiguous_local_config(tmp_path,
|
||||
|
||||
policy = GrootPolicy.from_pretrained(model_path)
|
||||
|
||||
assert policy.config.model_version == GROOT_N1_7
|
||||
assert policy.config.base_model_path == str(model_path)
|
||||
|
||||
|
||||
@@ -1340,8 +1323,7 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config(
|
||||
# so construction itself raises before from_pretrained is reached.
|
||||
with pytest.raises(ValueError, match="does not match base_model_path"):
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
base_model_path="nvidia/GR00T-N1.5-3B",
|
||||
input_features=input_features,
|
||||
output_features=output_features,
|
||||
device="cpu",
|
||||
@@ -1353,7 +1335,7 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config(
|
||||
|
||||
def test_groot_n1_7_processors_are_registered_lazily_without_external_gr00t():
|
||||
sys.modules.pop("gr00t", None)
|
||||
config = _groot_config(GROOT_N1_7)
|
||||
config = _groot_config()
|
||||
|
||||
preprocessor, _ = make_groot_pre_post_processors(config)
|
||||
step_types = {type(step) for step in preprocessor.steps}
|
||||
@@ -1692,7 +1674,7 @@ def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
|
||||
|
||||
|
||||
def test_groot_n1_7_saved_processors_reload_through_factory(tmp_path):
|
||||
config = _groot_config(GROOT_N1_7)
|
||||
config = _groot_config()
|
||||
dataset_stats = {
|
||||
OBS_STATE: {
|
||||
"min": torch.zeros(8),
|
||||
@@ -1724,7 +1706,7 @@ def test_groot_n1_7_saved_processors_reload_through_factory(tmp_path):
|
||||
|
||||
|
||||
def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stats(tmp_path):
|
||||
config = _groot_config(GROOT_N1_7)
|
||||
config = _groot_config()
|
||||
saved_stats = {
|
||||
OBS_STATE: {
|
||||
"min": torch.full((8,), -2.0),
|
||||
@@ -1774,7 +1756,6 @@ def test_groot_n1_7_relative_action_training_processors_save_relative_action_sta
|
||||
action_decode_transform=None,
|
||||
use_relative_actions=True,
|
||||
relative_exclude_joints=["gripper"],
|
||||
action_feature_names=action_names,
|
||||
)
|
||||
absolute_dataset_stats = {
|
||||
OBS_STATE: {
|
||||
@@ -1829,7 +1810,9 @@ def test_groot_n1_7_relative_action_training_processors_save_relative_action_sta
|
||||
"max": torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0, 100.0]),
|
||||
}
|
||||
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=relative_dataset_stats)
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(
|
||||
config, dataset_stats=relative_dataset_stats, dataset_meta=_RelativeStatsDataset.meta
|
||||
)
|
||||
preprocessor.save_pretrained(tmp_path)
|
||||
postprocessor.save_pretrained(tmp_path)
|
||||
|
||||
@@ -1867,7 +1850,7 @@ def test_groot_policy_selects_n1_7_model_class(monkeypatch):
|
||||
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(fake_from_pretrained))
|
||||
|
||||
policy = GrootPolicy(_groot_config(GROOT_N1_7))
|
||||
policy = GrootPolicy(_groot_config())
|
||||
|
||||
assert called["pretrained_model_name_or_path"] == GROOT_N1_7_BASE_MODEL
|
||||
assert isinstance(policy._groot_model, _DummyGrootModel)
|
||||
@@ -1878,7 +1861,7 @@ def test_groot_policy_forwards_n1_7_qwen_inputs(monkeypatch):
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
|
||||
policy = GrootPolicy(_groot_config(GROOT_N1_7))
|
||||
policy = GrootPolicy(_groot_config())
|
||||
|
||||
batch = {
|
||||
"state": torch.zeros(2, 1, 132),
|
||||
@@ -1941,7 +1924,6 @@ def test_groot_n1_7_select_action_uses_checkpoint_valid_horizon(tmp_path, monkey
|
||||
monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel()))
|
||||
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
|
||||
config = GrootConfig(
|
||||
model_version=GROOT_N1_7,
|
||||
base_model_path=str(model_path),
|
||||
embodiment_tag="libero_sim",
|
||||
input_features=input_features,
|
||||
|
||||
Reference in New Issue
Block a user