Fix policy.path in YAML configs (PR #3145 followup) (#3597)

PR #3145 added YAML support for policy.path but left two bugs:

1. extract_path_fields_from_config only deleted config_data[field] when
   no sibling overrides existed. With siblings, the dict stayed in place
   and draccus crashed decoding it as PreTrainedConfig (no 'type' key).
   Sibling overrides go into _config_yaml_overrides and are applied later
   by from_pretrained(), so the field can always be removed.

2. wrap() updated config_path_cli to the cleaned temp file path but
   never propagated it to the draccus.parse fallback branch. cli_args
   still contained --config_path=<original>, so draccus read the
   original YAML with path: still present.

Tests passed because they (a) called extract_path_fields_from_config
directly and (b) included type: alongside path: in the YAML, sidestepping
both bugs.

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Reece O'Mahoney
2026-05-26 13:01:19 +01:00
committed by GitHub
parent 8194897994
commit f65f3f7a4a
2 changed files with 117 additions and 10 deletions
+7 -2
View File
@@ -255,7 +255,6 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
remaining = config_data[field]
if remaining:
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
else:
del config_data[field]
modified = True
@@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
cli_args = filter_arg("config_path", cli_args)
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
else:
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
if config_path_cli:
cli_args = filter_arg("config_path", cli_args)
cfg = draccus.parse(
config_class=argtype,
config_path=config_path_cli or config_path,
args=cli_args,
)
response = fn(cfg, *args, **kwargs)
return response
+109 -7
View File
@@ -1,10 +1,14 @@
"""Tests for policy.path support in YAML config files (issue #2957)."""
import json
import sys
import tempfile
from dataclasses import dataclass, field
from unittest.mock import patch
import yaml
from lerobot.configs import parser
from lerobot.configs.parser import (
_config_path_args,
_config_yaml_overrides,
@@ -16,7 +20,8 @@ from lerobot.configs.parser import (
def test_extract_path_fields_from_yaml():
"""Test that policy.path is extracted from a YAML config and removed."""
"""Test that policy.path is extracted from a YAML config and the policy block
is removed entirely (siblings are captured separately as cli_overrides)."""
config = {
"dataset": {"repo_id": "lerobot/pusht"},
"policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
@@ -26,26 +31,33 @@ def test_extract_path_fields_from_yaml():
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
# Path should be extracted and stored
assert _config_path_args["policy"] == "lerobot/smolvla_base"
# Cleaned config should not have the path field
# Cleaned config should not have the policy block at all -- draccus must not
# try to decode it as PreTrainedConfig; the actual config comes from
# from_pretrained(path) with the captured overrides applied on top.
with open(cleaned_path) as f:
cleaned = yaml.safe_load(f)
assert "path" not in cleaned["policy"]
assert cleaned["policy"]["type"] == "smolvla"
assert cleaned["policy"]["push_to_hub"] is False
assert "policy" not in cleaned
# Original dataset should be untouched
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
# Sibling overrides (excluding type/path) captured for from_pretrained.
overrides = get_yaml_overrides("policy")
assert any("push_to_hub=false" in o for o in overrides)
_config_path_args.clear()
_config_yaml_overrides.clear()
def test_extract_path_fields_from_json():
"""Test that policy.path is extracted from a JSON config."""
"""Test that policy.path is extracted from a JSON config and the policy
block is removed entirely."""
config = {
"policy": {"type": "act", "path": "some/local/path"},
}
@@ -54,15 +66,17 @@ def test_extract_path_fields_from_json():
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
assert _config_path_args["policy"] == "some/local/path"
with open(cleaned_path) as f:
cleaned = json.load(f)
assert "path" not in cleaned["policy"]
assert "policy" not in cleaned
_config_path_args.clear()
_config_yaml_overrides.clear()
def test_extract_no_path_returns_original():
@@ -216,3 +230,91 @@ def test_flatten_nested_with_bools():
args = _flatten_to_cli_args(d)
assert "--optimizer.use_warmup=true" in args
assert "--optimizer.lr=0.01" in args
def test_extract_removes_field_with_siblings_and_no_type():
"""Regression: when policy.path has siblings but no type:, the entire policy
block must still be removed from the cleaned config. Otherwise draccus tries
to decode the leftover dict as PreTrainedConfig and crashes on the missing
type discriminator.
"""
config = {
"dataset": {"repo_id": "lerobot/pusht"},
"policy": {
"path": "lerobot/smolvla_base",
"n_action_steps": 10,
"dtype": "bfloat16",
},
}
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml.dump(config, f)
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
with open(cleaned_path) as f:
cleaned = yaml.safe_load(f) or {}
assert "policy" not in cleaned, "policy block should be fully removed when path is present"
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
assert _config_path_args["policy"] == "lerobot/smolvla_base"
overrides = get_yaml_overrides("policy")
assert any("n_action_steps=10" in o for o in overrides)
assert any("dtype=bfloat16" in o for o in overrides)
_config_path_args.clear()
_config_yaml_overrides.clear()
@dataclass
class _DummyNested:
foo: int = 0
@dataclass
class _DummyConfig:
nested: _DummyNested = field(default_factory=_DummyNested)
other: str = "default"
@classmethod
def __get_path_fields__(cls):
return ["nested"]
def test_wrap_uses_cleaned_config_for_draccus_parse():
"""Regression: wrap() updates config_path_cli to point at the cleaned temp
file but must propagate that to the draccus.parse fallback branch. Without
the fix, cli_args still contains --config_path=<original> and draccus reads
the original YAML with `path:` still in it, crashing on the unknown field.
"""
config = {
"nested": {"path": "some/checkpoint", "foo": 42},
"other": "set-via-yaml",
}
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
yaml.dump(config, f)
config_path = f.name
_config_path_args.clear()
_config_yaml_overrides.clear()
captured: dict = {}
@parser.wrap()
def main(cfg: _DummyConfig) -> _DummyConfig:
captured["cfg"] = cfg
return cfg
with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
main()
assert captured["cfg"].other == "set-via-yaml"
assert _config_path_args["nested"] == "some/checkpoint"
# Cleaned config dropped `nested:` entirely; defaults stand for this wrapper
# class (a real PreTrainedConfig would now load the checkpoint and apply
# the captured yaml_overrides via from_pretrained()).
assert captured["cfg"].nested.foo == 0
_config_path_args.clear()
_config_yaml_overrides.clear()