|
|
|
@@ -1,10 +1,14 @@
|
|
|
|
|
"""Tests for policy.path support in YAML config files (issue #2957)."""
|
|
|
|
|
|
|
|
|
|
import json
|
|
|
|
|
import sys
|
|
|
|
|
import tempfile
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
from unittest.mock import patch
|
|
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
|
|
|
|
|
from lerobot.configs import parser
|
|
|
|
|
from lerobot.configs.parser import (
|
|
|
|
|
_config_path_args,
|
|
|
|
|
_config_yaml_overrides,
|
|
|
|
@@ -16,7 +20,8 @@ from lerobot.configs.parser import (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_extract_path_fields_from_yaml():
|
|
|
|
|
"""Test that policy.path is extracted from a YAML config and removed."""
|
|
|
|
|
"""Test that policy.path is extracted from a YAML config and the policy block
|
|
|
|
|
is removed entirely (siblings are captured separately as cli_overrides)."""
|
|
|
|
|
config = {
|
|
|
|
|
"dataset": {"repo_id": "lerobot/pusht"},
|
|
|
|
|
"policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
|
|
|
|
@@ -26,26 +31,33 @@ def test_extract_path_fields_from_yaml():
|
|
|
|
|
config_path = f.name
|
|
|
|
|
|
|
|
|
|
_config_path_args.clear()
|
|
|
|
|
_config_yaml_overrides.clear()
|
|
|
|
|
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
|
|
|
|
|
|
|
|
|
# Path should be extracted and stored
|
|
|
|
|
assert _config_path_args["policy"] == "lerobot/smolvla_base"
|
|
|
|
|
|
|
|
|
|
# Cleaned config should not have the path field
|
|
|
|
|
# Cleaned config should not have the policy block at all -- draccus must not
|
|
|
|
|
# try to decode it as PreTrainedConfig; the actual config comes from
|
|
|
|
|
# from_pretrained(path) with the captured overrides applied on top.
|
|
|
|
|
with open(cleaned_path) as f:
|
|
|
|
|
cleaned = yaml.safe_load(f)
|
|
|
|
|
assert "path" not in cleaned["policy"]
|
|
|
|
|
assert cleaned["policy"]["type"] == "smolvla"
|
|
|
|
|
assert cleaned["policy"]["push_to_hub"] is False
|
|
|
|
|
assert "policy" not in cleaned
|
|
|
|
|
|
|
|
|
|
# Original dataset should be untouched
|
|
|
|
|
assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
|
|
|
|
|
|
|
|
|
|
# Sibling overrides (excluding type/path) captured for from_pretrained.
|
|
|
|
|
overrides = get_yaml_overrides("policy")
|
|
|
|
|
assert any("push_to_hub=false" in o for o in overrides)
|
|
|
|
|
|
|
|
|
|
_config_path_args.clear()
|
|
|
|
|
_config_yaml_overrides.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_extract_path_fields_from_json():
|
|
|
|
|
"""Test that policy.path is extracted from a JSON config."""
|
|
|
|
|
"""Test that policy.path is extracted from a JSON config and the policy
|
|
|
|
|
block is removed entirely."""
|
|
|
|
|
config = {
|
|
|
|
|
"policy": {"type": "act", "path": "some/local/path"},
|
|
|
|
|
}
|
|
|
|
@@ -54,15 +66,17 @@ def test_extract_path_fields_from_json():
|
|
|
|
|
config_path = f.name
|
|
|
|
|
|
|
|
|
|
_config_path_args.clear()
|
|
|
|
|
_config_yaml_overrides.clear()
|
|
|
|
|
cleaned_path = extract_path_fields_from_config(config_path, ["policy"])
|
|
|
|
|
|
|
|
|
|
assert _config_path_args["policy"] == "some/local/path"
|
|
|
|
|
|
|
|
|
|
with open(cleaned_path) as f:
|
|
|
|
|
cleaned = json.load(f)
|
|
|
|
|
assert "path" not in cleaned["policy"]
|
|
|
|
|
assert "policy" not in cleaned
|
|
|
|
|
|
|
|
|
|
_config_path_args.clear()
|
|
|
|
|
_config_yaml_overrides.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_extract_no_path_returns_original():
|
|
|
|
@@ -216,3 +230,91 @@ def test_flatten_nested_with_bools():
|
|
|
|
|
args = _flatten_to_cli_args(d)
|
|
|
|
|
assert "--optimizer.use_warmup=true" in args
|
|
|
|
|
assert "--optimizer.lr=0.01" in args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_extract_removes_field_with_siblings_and_no_type():
    """Regression: a policy block with ``path`` and siblings but no ``type`` is dropped.

    When policy.path has siblings but no type:, the entire policy block must
    still be removed from the cleaned config. Otherwise draccus tries to decode
    the leftover dict as PreTrainedConfig and crashes on the missing type
    discriminator.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {
            "path": "lerobot/smolvla_base",
            "n_action_steps": 10,
            "dtype": "bfloat16",
        },
    }

    # Start from a clean slate: both registries are module-level state shared
    # with every other test in this file.
    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        # A TemporaryDirectory (instead of NamedTemporaryFile(delete=False))
        # guarantees the input config is deleted when the test finishes.
        with tempfile.TemporaryDirectory() as tmp_dir:
            config_path = f"{tmp_dir}/config.yaml"
            with open(config_path, "w") as f:
                yaml.dump(config, f)

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            # Where the cleaned file is written is not visible from this test,
            # so read it while the temporary directory still exists.
            with open(cleaned_path) as f:
                # safe_load yields None for an empty document; normalise to {}.
                cleaned = yaml.safe_load(f) or {}

        assert "policy" not in cleaned, "policy block should be fully removed when path is present"
        assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
        assert _config_path_args["policy"] == "lerobot/smolvla_base"

        # The siblings of `path` must survive as captured overrides.
        overrides = get_yaml_overrides("policy")
        assert any("n_action_steps=10" in o for o in overrides)
        assert any("dtype=bfloat16" in o for o in overrides)
    finally:
        # Reset the shared registries even when an assertion above fails, so a
        # failure here cannot leak state into unrelated tests.
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class _DummyNested:
|
|
|
|
|
foo: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class _DummyConfig:
    """Small stand-in config with one nested section and one plain field."""

    # Built through default_factory so each instance gets its own
    # _DummyNested object instead of sharing one.
    nested: _DummyNested = field(default_factory=_DummyNested)
    other: str = "default"

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """Return the field names that may be supplied via a ``path`` entry.

        NOTE(review): presumably consulted by ``parser.wrap()`` to decide
        which config sections to extract -- confirm against the parser.
        """
        return ["nested"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_wrap_uses_cleaned_config_for_draccus_parse():
    """Regression: wrap() must hand the cleaned config file to draccus.parse.

    wrap() updates config_path_cli to point at the cleaned temp file but must
    propagate that to the draccus.parse fallback branch. Without the fix,
    cli_args still contains --config_path=<original> and draccus reads the
    original YAML with `path:` still in it, crashing on the unknown field.
    """
    config = {
        "nested": {"path": "some/checkpoint", "foo": 42},
        "other": "set-via-yaml",
    }

    # Filled in by the wrapped entry point so the parsed config can be
    # inspected after the call.
    captured: dict = {}

    @parser.wrap()
    def main(cfg: _DummyConfig) -> _DummyConfig:
        captured["cfg"] = cfg
        return cfg

    # Both registries are module-level state shared with the other tests.
    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        # A TemporaryDirectory (instead of NamedTemporaryFile(delete=False))
        # guarantees the input config is deleted when the test finishes.
        with tempfile.TemporaryDirectory() as tmp_dir:
            config_path = f"{tmp_dir}/config.yaml"
            with open(config_path, "w") as f:
                yaml.dump(config, f)

            # wrap() reads its arguments from sys.argv, so fake the command
            # line; the config file must still exist while main() runs.
            with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
                main()

        assert captured["cfg"].other == "set-via-yaml"
        assert _config_path_args["nested"] == "some/checkpoint"
        # Cleaned config dropped `nested:` entirely; defaults stand for this
        # wrapper class (a real PreTrainedConfig would now load the checkpoint
        # and apply the captured yaml_overrides via from_pretrained()).
        assert captured["cfg"].nested.foo == 0
    finally:
        # Reset the shared registries even when an assertion above fails, so a
        # failure here cannot leak state into unrelated tests.
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|