diff --git a/src/lerobot/configs/parser.py b/src/lerobot/configs/parser.py index d55fa44aa..46cff2b48 100644 --- a/src/lerobot/configs/parser.py +++ b/src/lerobot/configs/parser.py @@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> remaining = config_data[field] if remaining: _config_yaml_overrides[field] = _flatten_to_cli_args(remaining) - else: - del config_data[field] + del config_data[field] modified = True if not modified: @@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]: cli_args = filter_arg("config_path", cli_args) cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args) else: - cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args) + if config_path_cli: + cli_args = filter_arg("config_path", cli_args) + cfg = draccus.parse( + config_class=argtype, + config_path=config_path_cli or config_path, + args=cli_args, + ) response = fn(cfg, *args, **kwargs) return response diff --git a/tests/test_yaml_policy_path.py b/tests/test_yaml_policy_path.py index 710a71c9a..8d8f7f2ec 100644 --- a/tests/test_yaml_policy_path.py +++ b/tests/test_yaml_policy_path.py @@ -1,10 +1,14 @@ """Tests for policy.path support in YAML config files (issue #2957).""" import json +import sys import tempfile +from dataclasses import dataclass, field +from unittest.mock import patch import yaml +from lerobot.configs import parser from lerobot.configs.parser import ( _config_path_args, _config_yaml_overrides, @@ -16,7 +20,8 @@ from lerobot.configs.parser import ( def test_extract_path_fields_from_yaml(): - """Test that policy.path is extracted from a YAML config and removed.""" + """Test that policy.path is extracted from a YAML config and the policy block + is removed entirely (siblings are captured separately as cli_overrides).""" config = { "dataset": {"repo_id": "lerobot/pusht"}, "policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False}, @@ -26,26 +31,33 @@ def test_extract_path_fields_from_yaml(): config_path = f.name _config_path_args.clear() + _config_yaml_overrides.clear() cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) # Path should be extracted and stored assert _config_path_args["policy"] == "lerobot/smolvla_base" - # Cleaned config should not have the path field + # Cleaned config should not have the policy block at all -- draccus must not + # try to decode it as PreTrainedConfig; the actual config comes from + # from_pretrained(path) with the captured overrides applied on top. with open(cleaned_path) as f: cleaned = yaml.safe_load(f) - assert "path" not in cleaned["policy"] - assert cleaned["policy"]["type"] == "smolvla" - assert cleaned["policy"]["push_to_hub"] is False + assert "policy" not in cleaned # Original dataset should be untouched assert cleaned["dataset"]["repo_id"] == "lerobot/pusht" + # Sibling overrides (excluding type/path) captured for from_pretrained. + overrides = get_yaml_overrides("policy") + assert any("push_to_hub=false" in o for o in overrides) + _config_path_args.clear() + _config_yaml_overrides.clear() def test_extract_path_fields_from_json(): - """Test that policy.path is extracted from a JSON config.""" + """Test that policy.path is extracted from a JSON config and the policy + block is removed entirely.""" config = { "policy": {"type": "act", "path": "some/local/path"}, } @@ -54,15 +66,17 @@ def test_extract_path_fields_from_json(): config_path = f.name _config_path_args.clear() + _config_yaml_overrides.clear() cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) assert _config_path_args["policy"] == "some/local/path" with open(cleaned_path) as f: cleaned = json.load(f) - assert "path" not in cleaned["policy"] + assert "policy" not in cleaned _config_path_args.clear() + _config_yaml_overrides.clear() def test_extract_no_path_returns_original(): @@ -216,3 +230,91 @@ def test_flatten_nested_with_bools(): args = _flatten_to_cli_args(d) assert "--optimizer.use_warmup=true" in args assert "--optimizer.lr=0.01" in args + + +def test_extract_removes_field_with_siblings_and_no_type(): + """Regression: when policy.path has siblings but no type:, the entire policy + block must still be removed from the cleaned config. Otherwise draccus tries + to decode the leftover dict as PreTrainedConfig and crashes on the missing + type discriminator. + """ + config = { + "dataset": {"repo_id": "lerobot/pusht"}, + "policy": { + "path": "lerobot/smolvla_base", + "n_action_steps": 10, + "dtype": "bfloat16", + }, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) + + with open(cleaned_path) as f: + cleaned = yaml.safe_load(f) or {} + assert "policy" not in cleaned, "policy block should be fully removed when path is present" + assert cleaned["dataset"]["repo_id"] == "lerobot/pusht" + assert _config_path_args["policy"] == "lerobot/smolvla_base" + overrides = get_yaml_overrides("policy") + assert any("n_action_steps=10" in o for o in overrides) + assert any("dtype=bfloat16" in o for o in overrides) + + _config_path_args.clear() + _config_yaml_overrides.clear() + + +@dataclass +class _DummyNested: + foo: int = 0 + + +@dataclass +class _DummyConfig: + nested: _DummyNested = field(default_factory=_DummyNested) + other: str = "default" + + @classmethod + def __get_path_fields__(cls): + return ["nested"] + + +def test_wrap_uses_cleaned_config_for_draccus_parse(): + """Regression: wrap() updates config_path_cli to point at the cleaned temp + file but must propagate that to the draccus.parse fallback branch. Without + the fix, cli_args still contains --config_path= and draccus reads + the original YAML with `path:` still in it, crashing on the unknown field. + """ + config = { + "nested": {"path": "some/checkpoint", "foo": 42}, + "other": "set-via-yaml", + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + + captured: dict = {} + + @parser.wrap() + def main(cfg: _DummyConfig) -> _DummyConfig: + captured["cfg"] = cfg + return cfg + + with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]): + main() + + assert captured["cfg"].other == "set-via-yaml" + assert _config_path_args["nested"] == "some/checkpoint" + # Cleaned config dropped `nested:` entirely; defaults stand for this wrapper + # class (a real PreTrainedConfig would now load the checkpoint and apply + # the captured yaml_overrides via from_pretrained()). + assert captured["cfg"].nested.foo == 0 + + _config_path_args.clear() + _config_yaml_overrides.clear()