mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
PR #3145 added YAML support for policy.path but left two bugs: 1. extract_path_fields_from_config only deleted config_data[field] when no sibling overrides existed. With siblings, the dict stayed in place and draccus crashed decoding it as PreTrainedConfig (no 'type' key). Sibling overrides go into _config_yaml_overrides and are applied later by from_pretrained(), so the field can always be removed. 2. wrap() updated config_path_cli to the cleaned temp file path but never propagated it to the draccus.parse fallback branch. cli_args still contained --config_path=<original>, so draccus read the original YAML with path: still present. Tests passed because they (a) called extract_path_fields_from_config directly and (b) included type: alongside path: in the YAML, sidestepping both bugs. Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -255,8 +255,7 @@ def extract_path_fields_from_config(config_path: str, path_fields: list[str]) ->
|
|||||||
remaining = config_data[field]
|
remaining = config_data[field]
|
||||||
if remaining:
|
if remaining:
|
||||||
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||||
else:
|
del config_data[field]
|
||||||
del config_data[field]
|
|
||||||
modified = True
|
modified = True
|
||||||
|
|
||||||
if not modified:
|
if not modified:
|
||||||
@@ -311,7 +310,13 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
|||||||
cli_args = filter_arg("config_path", cli_args)
|
cli_args = filter_arg("config_path", cli_args)
|
||||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
cfg = draccus.parse(config_class=argtype, config_path=config_path, args=cli_args)
|
if config_path_cli:
|
||||||
|
cli_args = filter_arg("config_path", cli_args)
|
||||||
|
cfg = draccus.parse(
|
||||||
|
config_class=argtype,
|
||||||
|
config_path=config_path_cli or config_path,
|
||||||
|
args=cli_args,
|
||||||
|
)
|
||||||
response = fn(cfg, *args, **kwargs)
|
response = fn(cfg, *args, **kwargs)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
"""Tests for policy.path support in YAML config files (issue #2957)."""
|
"""Tests for policy.path support in YAML config files (issue #2957)."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.parser import (
|
from lerobot.configs.parser import (
|
||||||
_config_path_args,
|
_config_path_args,
|
||||||
_config_yaml_overrides,
|
_config_yaml_overrides,
|
||||||
@@ -16,7 +20,8 @@ from lerobot.configs.parser import (
|
|||||||
|
|
||||||
|
|
||||||
def test_extract_path_fields_from_yaml():
|
def test_extract_path_fields_from_yaml():
|
||||||
"""Test that policy.path is extracted from a YAML config and removed."""
|
"""Test that policy.path is extracted from a YAML config and the policy block
|
||||||
|
is removed entirely (siblings are captured separately as cli_overrides)."""
|
||||||
config = {
|
config = {
|
||||||
"dataset": {"repo_id": "lerobot/pusht"},
|
"dataset": {"repo_id": "lerobot/pusht"},
|
||||||
"policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
|
"policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
|
||||||
@@ -26,26 +31,33 @@ def test_extract_path_fields_from_yaml():
|
|||||||
config_path = f.name
|
config_path = f.name
|
||||||
|
|
||||||
_config_path_args.clear()
|
_config_path_args.clear()
|
||||||
|
_config_yaml_overrides.clear()
|
||||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||||
|
|
||||||
# Path should be extracted and stored
|
# Path should be extracted and stored
|
||||||
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
||||||
|
|
||||||
# Cleaned config should not have the path field
|
# Cleaned config should not have the policy block at all -- draccus must not
|
||||||
|
# try to decode it as PreTrainedConfig; the actual config comes from
|
||||||
|
# from_pretrained(path) with the captured overrides applied on top.
|
||||||
with open(cleaned_path) as f:
|
with open(cleaned_path) as f:
|
||||||
cleaned = yaml.safe_load(f)
|
cleaned = yaml.safe_load(f)
|
||||||
assert "path" not in cleaned["policy"]
|
assert "policy" not in cleaned
|
||||||
assert cleaned["policy"]["type"] == "smolvla"
|
|
||||||
assert cleaned["policy"]["push_to_hub"] is False
|
|
||||||
|
|
||||||
# Original dataset should be untouched
|
# Original dataset should be untouched
|
||||||
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
||||||
|
|
||||||
|
# Sibling overrides (excluding type/path) captured for from_pretrained.
|
||||||
|
overrides = get_yaml_overrides("policy")
|
||||||
|
assert any("push_to_hub=false" in o for o in overrides)
|
||||||
|
|
||||||
_config_path_args.clear()
|
_config_path_args.clear()
|
||||||
|
_config_yaml_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
def test_extract_path_fields_from_json():
|
def test_extract_path_fields_from_json():
|
||||||
"""Test that policy.path is extracted from a JSON config."""
|
"""Test that policy.path is extracted from a JSON config and the policy
|
||||||
|
block is removed entirely."""
|
||||||
config = {
|
config = {
|
||||||
"policy": {"type": "act", "path": "some/local/path"},
|
"policy": {"type": "act", "path": "some/local/path"},
|
||||||
}
|
}
|
||||||
@@ -54,15 +66,17 @@ def test_extract_path_fields_from_json():
|
|||||||
config_path = f.name
|
config_path = f.name
|
||||||
|
|
||||||
_config_path_args.clear()
|
_config_path_args.clear()
|
||||||
|
_config_yaml_overrides.clear()
|
||||||
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||||
|
|
||||||
assert _config_path_args["policy"] == "some/local/path"
|
assert _config_path_args["policy"] == "some/local/path"
|
||||||
|
|
||||||
with open(cleaned_path) as f:
|
with open(cleaned_path) as f:
|
||||||
cleaned = json.load(f)
|
cleaned = json.load(f)
|
||||||
assert "path" not in cleaned["policy"]
|
assert "policy" not in cleaned
|
||||||
|
|
||||||
_config_path_args.clear()
|
_config_path_args.clear()
|
||||||
|
_config_yaml_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
def test_extract_no_path_returns_original():
|
def test_extract_no_path_returns_original():
|
||||||
@@ -216,3 +230,91 @@ def test_flatten_nested_with_bools():
|
|||||||
args = _flatten_to_cli_args(d)
|
args = _flatten_to_cli_args(d)
|
||||||
assert "--optimizer.use_warmup=true" in args
|
assert "--optimizer.use_warmup=true" in args
|
||||||
assert "--optimizer.lr=0.01" in args
|
assert "--optimizer.lr=0.01" in args
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_removes_field_with_siblings_and_no_type():
|
||||||
|
"""Regression: when policy.path has siblings but no type:, the entire policy
|
||||||
|
block must still be removed from the cleaned config. Otherwise draccus tries
|
||||||
|
to decode the leftover dict as PreTrainedConfig and crashes on the missing
|
||||||
|
type discriminator.
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"dataset": {"repo_id": "lerobot/pusht"},
|
||||||
|
"policy": {
|
||||||
|
"path": "lerobot/smolvla_base",
|
||||||
|
"n_action_steps": 10,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||||
|
yaml.dump(config, f)
|
||||||
|
config_path = f.name
|
||||||
|
|
||||||
|
_config_path_args.clear()
|
||||||
|
_config_yaml_overrides.clear()
|
||||||
|
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
||||||
|
|
||||||
|
with open(cleaned_path) as f:
|
||||||
|
cleaned = yaml.safe_load(f) or {}
|
||||||
|
assert "policy" not in cleaned, "policy block should be fully removed when path is present"
|
||||||
|
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
||||||
|
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
||||||
|
overrides = get_yaml_overrides("policy")
|
||||||
|
assert any("n_action_steps=10" in o for o in overrides)
|
||||||
|
assert any("dtype=bfloat16" in o for o in overrides)
|
||||||
|
|
||||||
|
_config_path_args.clear()
|
||||||
|
_config_yaml_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _DummyNested:
|
||||||
|
foo: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _DummyConfig:
|
||||||
|
nested: _DummyNested = field(default_factory=_DummyNested)
|
||||||
|
other: str = "default"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_path_fields__(cls):
|
||||||
|
return ["nested"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrap_uses_cleaned_config_for_draccus_parse():
|
||||||
|
"""Regression: wrap() updates config_path_cli to point at the cleaned temp
|
||||||
|
file but must propagate that to the draccus.parse fallback branch. Without
|
||||||
|
the fix, cli_args still contains --config_path=<original> and draccus reads
|
||||||
|
the original YAML with `path:` still in it, crashing on the unknown field.
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"nested": {"path": "some/checkpoint", "foo": 42},
|
||||||
|
"other": "set-via-yaml",
|
||||||
|
}
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||||
|
yaml.dump(config, f)
|
||||||
|
config_path = f.name
|
||||||
|
|
||||||
|
_config_path_args.clear()
|
||||||
|
_config_yaml_overrides.clear()
|
||||||
|
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
@parser.wrap()
|
||||||
|
def main(cfg: _DummyConfig) -> _DummyConfig:
|
||||||
|
captured["cfg"] = cfg
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
|
||||||
|
main()
|
||||||
|
|
||||||
|
assert captured["cfg"].other == "set-via-yaml"
|
||||||
|
assert _config_path_args["nested"] == "some/checkpoint"
|
||||||
|
# Cleaned config dropped `nested:` entirely; defaults stand for this wrapper
|
||||||
|
# class (a real PreTrainedConfig would now load the checkpoint and apply
|
||||||
|
# the captured yaml_overrides via from_pretrained()).
|
||||||
|
assert captured["cfg"].nested.foo == 0
|
||||||
|
|
||||||
|
_config_path_args.clear()
|
||||||
|
_config_yaml_overrides.clear()
|
||||||
|
|||||||
Reference in New Issue
Block a user