Files
lerobot/tests/test_yaml_policy_path.py
T
Reece O'Mahoney f65f3f7a4a Fix policy.path in YAML configs (PR #3145 followup) (#3597)
PR #3145 added YAML support for policy.path but left two bugs:

1. extract_path_fields_from_config only deleted config_data[field] when
   no sibling overrides existed. With siblings, the dict stayed in place
   and draccus crashed decoding it as PreTrainedConfig (no 'type' key).
   Sibling overrides go into _config_yaml_overrides and are applied later
   by from_pretrained(), so the field can always be removed.

2. wrap() updated config_path_cli to the cleaned temp file path but
   never propagated it to the draccus.parse fallback branch. cli_args
   still contained --config_path=<original>, so draccus read the
   original YAML with path: still present.

Tests passed because they (a) called extract_path_fields_from_config
directly and (b) included type: alongside path: in the YAML, sidestepping
both bugs.

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-26 14:01:19 +02:00

321 lines
10 KiB
Python

"""Tests for policy.path support in YAML config files (issue #2957)."""
import json
import os
import sys
import tempfile
from dataclasses import dataclass, field
from unittest.mock import patch

import yaml

from lerobot.configs import parser
from lerobot.configs.parser import (
    _config_path_args,
    _config_yaml_overrides,
    _flatten_to_cli_args,
    extract_path_fields_from_config,
    get_path_arg,
    get_yaml_overrides,
)
def test_extract_path_fields_from_yaml():
    """Extract ``policy.path`` from a YAML config and drop the whole policy block.

    The config holds a ``policy`` mapping with ``type``, ``path`` and one sibling
    (``push_to_hub``). After ``extract_path_fields_from_config`` runs, the test
    expects the path to be recorded in ``_config_path_args``, the ``policy`` key
    to be absent from the cleaned file, the ``dataset`` section to be unchanged,
    and the sibling to be reported by ``get_yaml_overrides("policy")``.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
    }
    # delete=False keeps the file on disk after the handle closes so the parser
    # can reopen it by name; it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(config, f)
        config_path = f.name

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

        # Path should be extracted and stored
        assert _config_path_args["policy"] == "lerobot/smolvla_base"

        # Cleaned config should not have the policy block at all -- draccus must
        # not try to decode it as PreTrainedConfig; the actual config comes from
        # from_pretrained(path) with the captured overrides applied on top.
        with open(cleaned_path) as f:
            cleaned = yaml.safe_load(f)
        assert "policy" not in cleaned

        # Original dataset should be untouched
        assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"

        # Sibling overrides (excluding type/path) captured for from_pretrained.
        overrides = get_yaml_overrides("policy")
        assert any("push_to_hub=false" in o for o in overrides)
    finally:
        # Runs even when an assertion above fails, so the module-level
        # registries are reset and the temp config does not pile up on disk.
        _config_path_args.clear()
        _config_yaml_overrides.clear()
        os.remove(config_path)
def test_extract_path_fields_from_json():
    """Extract ``policy.path`` from a JSON config and drop the whole policy block.

    Mirrors the YAML case with a ``.json`` file: the path must land in
    ``_config_path_args`` and the cleaned file must no longer contain a
    ``policy`` key.
    """
    config = {
        "policy": {"type": "act", "path": "some/local/path"},
    }
    # delete=False so the file can be reopened by name after the handle closes;
    # it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
        json.dump(config, f)
        config_path = f.name

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

        assert _config_path_args["policy"] == "some/local/path"

        with open(cleaned_path) as f:
            cleaned = json.load(f)
        assert "policy" not in cleaned
    finally:
        # Reset shared state and delete the temp config even on failure.
        _config_path_args.clear()
        _config_yaml_overrides.clear()
        os.remove(config_path)
def test_extract_no_path_returns_original():
    """Configs without a ``path`` field are returned unchanged.

    The ``policy`` block only has ``type``, so the test expects
    ``extract_path_fields_from_config`` to hand back the very same file path
    and to leave ``_config_path_args`` without a ``policy`` entry.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"type": "smolvla"},
    }
    # delete=False so the file can be reopened by name after the handle closes;
    # it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(config, f)
        config_path = f.name

    _config_path_args.clear()
    try:
        result = extract_path_fields_from_config(config_path, ["policy"])

        assert result == config_path
        assert "policy" not in _config_path_args
    finally:
        # Reset shared state and delete the temp config even on failure.
        _config_path_args.clear()
        os.remove(config_path)
def test_extract_removes_empty_field():
    """The field mapping is removed entirely when ``path`` was its only key.

    Expects the path to be stored in ``_config_path_args`` and the ``policy``
    key to be missing from the cleaned YAML.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"path": "lerobot/smolvla_base"},
    }
    # delete=False so the file can be reopened by name after the handle closes;
    # it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(config, f)
        config_path = f.name

    _config_path_args.clear()
    try:
        cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

        assert _config_path_args["policy"] == "lerobot/smolvla_base"

        with open(cleaned_path) as f:
            cleaned = yaml.safe_load(f)
        assert "policy" not in cleaned
    finally:
        # Reset shared state and delete the temp config even on failure.
        _config_path_args.clear()
        os.remove(config_path)
def test_get_path_arg_fallback():
    """``get_path_arg`` falls back to ``_config_path_args`` when the CLI has no path.

    The registry is seeded by hand and ``get_path_arg`` is called with an empty
    argument list; the seeded value must be returned.
    """
    _config_path_args.clear()
    _config_path_args["policy"] = "lerobot/smolvla_base"
    try:
        # No CLI args with --policy.path
        result = get_path_arg("policy", args=[])
        assert result == "lerobot/smolvla_base"
    finally:
        # Drop the hand-seeded entry even if the assertion fails, so it cannot
        # leak into whatever test runs next.
        _config_path_args.clear()
def test_get_path_arg_cli_takes_precedence():
    """A ``--policy.path`` CLI argument wins over the path recorded from YAML.

    The registry is seeded with ``"yaml/path"`` and ``get_path_arg`` receives
    ``--policy.path=cli/path``; the CLI value must be returned.
    """
    _config_path_args.clear()
    _config_path_args["policy"] = "yaml/path"
    try:
        result = get_path_arg("policy", args=["--policy.path=cli/path"])
        assert result == "cli/path"
    finally:
        # Drop the hand-seeded entry even if the assertion fails, so it cannot
        # leak into whatever test runs next.
        _config_path_args.clear()
def test_yaml_overrides_captured():
    """Non-path policy fields are captured as CLI-style overrides.

    With ``path``, ``lr`` and ``batch_size`` in the ``policy`` block, the
    overrides reported by ``get_yaml_overrides("policy")`` must mention both
    ``lr`` and ``batch_size=32``.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base", "lr": 1e-4, "batch_size": 32},
    }
    # delete=False so the file can be reopened by name after the handle closes;
    # it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(config, f)
        config_path = f.name

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        extract_path_fields_from_config(config_path, ["policy"])

        overrides = get_yaml_overrides("policy")
        # Only the presence of an ``lr=`` override is checked. The earlier
        # ``"--lr=0.0001" in overrides or ...`` form was redundant: an exact
        # "--lr=0.0001" entry already satisfies this substring check.
        assert any("lr=" in o for o in overrides)
        assert any("batch_size=32" in o for o in overrides)
    finally:
        # Reset shared state and delete the temp config even on failure.
        _config_path_args.clear()
        _config_yaml_overrides.clear()
        os.remove(config_path)
def test_yaml_overrides_excludes_type_and_path():
    """``type`` and ``path`` are left out of the captured YAML overrides.

    The ``policy`` block carries ``path``, ``type`` and ``lr``; only ``lr`` may
    show up in ``get_yaml_overrides("policy")``.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base", "type": "smolvla", "lr": 5e-5},
    }
    # delete=False so the file can be reopened by name after the handle closes;
    # it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(config, f)
        config_path = f.name

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        extract_path_fields_from_config(config_path, ["policy"])

        overrides = get_yaml_overrides("policy")
        assert not any("path=" in o for o in overrides)
        assert not any("type=" in o for o in overrides)
        assert any("lr=" in o for o in overrides)
    finally:
        # Reset shared state and delete the temp config even on failure.
        _config_path_args.clear()
        _config_yaml_overrides.clear()
        os.remove(config_path)
def test_get_yaml_overrides_empty_when_path_only():
    """``get_yaml_overrides`` returns ``[]`` when the policy block had only ``path``."""
    config = {
        "policy": {"path": "lerobot/smolvla_base"},
    }
    # delete=False so the file can be reopened by name after the handle closes;
    # it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(config, f)
        config_path = f.name

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        extract_path_fields_from_config(config_path, ["policy"])

        assert get_yaml_overrides("policy") == []
    finally:
        # Reset shared state and delete the temp config even on failure.
        _config_path_args.clear()
        _config_yaml_overrides.clear()
        os.remove(config_path)
def test_flatten_bool_values():
    """Booleans come out of ``_flatten_to_cli_args`` as lowercase strings.

    A flat mapping with two booleans, a float and a string is flattened; each
    key must appear as a ``--key=value`` entry, with ``True``/``False``
    rendered as ``true``/``false``.
    """
    flattened = _flatten_to_cli_args(
        {"push_to_hub": True, "use_rabc": False, "lr": 0.001, "name": "test"}
    )
    expected_entries = (
        "--push_to_hub=true",
        "--use_rabc=false",
        "--lr=0.001",
        "--name=test",
    )
    for entry in expected_entries:
        assert entry in flattened
def test_flatten_none_values_skipped():
    """Keys whose value is ``None`` are absent from the flattened arguments.

    ``lr`` and ``name`` must each be present in some argument, while no
    argument may mention ``path_override``.
    """
    flattened = _flatten_to_cli_args({"lr": 0.001, "path_override": None, "name": "test"})
    for kept in ("lr=", "name="):
        assert any(kept in arg for arg in flattened)
    assert all("path_override" not in arg for arg in flattened)
def test_flatten_nested_with_bools():
    """Booleans inside nested mappings are flattened with dotted keys.

    ``{"optimizer": {...}}`` must produce ``--optimizer.<key>=<value>`` entries,
    with the nested ``True`` rendered as ``true``.
    """
    nested = {"optimizer": {"use_warmup": True, "lr": 0.01}}
    flattened = _flatten_to_cli_args(nested)
    for entry in ("--optimizer.use_warmup=true", "--optimizer.lr=0.01"):
        assert entry in flattened
def test_extract_removes_field_with_siblings_and_no_type():
    """Regression: drop the policy block even with siblings and no ``type:``.

    When ``policy.path`` has sibling keys but no ``type:``, the entire policy
    block must still be removed from the cleaned config. Otherwise draccus tries
    to decode the leftover dict as PreTrainedConfig and crashes on the missing
    type discriminator. The siblings must instead be reported by
    ``get_yaml_overrides("policy")``.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {
            "path": "lerobot/smolvla_base",
            "n_action_steps": 10,
            "dtype": "bfloat16",
        },
    }
    # delete=False so the file can be reopened by name after the handle closes;
    # it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(config, f)
        config_path = f.name

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

        with open(cleaned_path) as f:
            # ``or {}`` guards against safe_load returning None for an empty file.
            cleaned = yaml.safe_load(f) or {}

        assert "policy" not in cleaned, "policy block should be fully removed when path is present"
        assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
        assert _config_path_args["policy"] == "lerobot/smolvla_base"

        overrides = get_yaml_overrides("policy")
        assert any("n_action_steps=10" in o for o in overrides)
        assert any("dtype=bfloat16" in o for o in overrides)
    finally:
        # Reset shared state and delete the temp config even on failure.
        _config_path_args.clear()
        _config_yaml_overrides.clear()
        os.remove(config_path)
@dataclass
class _DummyNested:
    """Tiny nested config with a single integer field, used as a test stand-in."""

    # Defaults to 0 when no value is supplied.
    foo: int = field(default=0)
@dataclass
class _DummyConfig:
    """Minimal top-level config for exercising ``parser.wrap`` in tests.

    It has one nested dataclass field and one plain string field, and declares
    ``nested`` as its only path field.
    """

    # A fresh _DummyNested is built for every instance.
    nested: _DummyNested = field(default_factory=_DummyNested)
    other: str = field(default="default")

    @classmethod
    def __get_path_fields__(cls):
        """Return the names of the fields that may carry a ``path`` entry.

        NOTE(review): presumably this is the hook ``parser.wrap`` consults to
        learn which fields to scan -- confirm against the parser module.
        """
        path_fields = ["nested"]
        return path_fields
def test_wrap_uses_cleaned_config_for_draccus_parse():
    """Regression: ``wrap()`` must hand the cleaned config to draccus.

    wrap() updates config_path_cli to point at the cleaned temp file but must
    propagate that to the draccus.parse fallback branch. Without the fix,
    cli_args still contains --config_path=<original> and draccus reads the
    original YAML with `path:` still in it, crashing on the unknown field.

    The test drives a ``parser.wrap()``-decorated function through a patched
    ``sys.argv`` and inspects the config object it receives.
    """
    config = {
        "nested": {"path": "some/checkpoint", "foo": 42},
        "other": "set-via-yaml",
    }
    # delete=False so the file can be reopened by name after the handle closes;
    # it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
        yaml.dump(config, f)
        config_path = f.name

    _config_path_args.clear()
    _config_yaml_overrides.clear()

    # Filled in by ``main`` so the parsed config can be inspected afterwards.
    captured: dict = {}

    @parser.wrap()
    def main(cfg: _DummyConfig) -> _DummyConfig:
        captured["cfg"] = cfg
        return cfg

    try:
        with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
            main()

        assert captured["cfg"].other == "set-via-yaml"
        assert _config_path_args["nested"] == "some/checkpoint"
        # Cleaned config dropped `nested:` entirely; defaults stand for this
        # wrapper class (a real PreTrainedConfig would now load the checkpoint
        # and apply the captured yaml_overrides via from_pretrained()).
        assert captured["cfg"].nested.foo == 0
    finally:
        # Reset shared state and delete the temp config even on failure.
        _config_path_args.clear()
        _config_yaml_overrides.clear()
        os.remove(config_path)