mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
f65f3f7a4a
PR #3145 added YAML support for policy.path but left two bugs:

1. extract_path_fields_from_config only deleted config_data[field] when no sibling overrides existed. With siblings, the dict stayed in place and draccus crashed while decoding it as PreTrainedConfig (no 'type' key). Sibling overrides go into _config_yaml_overrides and are applied later by from_pretrained(), so the field can always be removed.

2. wrap() updated config_path_cli to the cleaned temp file path but never propagated it to the draccus.parse fallback branch. cli_args still contained --config_path=<original>, so draccus read the original YAML with `path:` still present.

Tests passed because they (a) called extract_path_fields_from_config directly and (b) included `type:` alongside `path:` in the YAML, sidestepping both bugs.

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
321 lines
10 KiB
Python
321 lines
10 KiB
Python
"""Tests for policy.path support in YAML config files (issue #2957)."""
|
|
|
|
import json
|
|
import sys
|
|
import tempfile
|
|
from dataclasses import dataclass, field
|
|
from unittest.mock import patch
|
|
|
|
import yaml
|
|
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.parser import (
|
|
_config_path_args,
|
|
_config_yaml_overrides,
|
|
_flatten_to_cli_args,
|
|
extract_path_fields_from_config,
|
|
get_path_arg,
|
|
get_yaml_overrides,
|
|
)
|
|
|
|
|
|
def test_extract_path_fields_from_yaml():
    """Test that policy.path is extracted from a YAML config and the policy block
    is removed entirely (siblings are captured separately as cli_overrides).

    The input file is written into a temporary directory that is deleted when
    the test finishes, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak state into other tests.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed; the
            # enclosing TemporaryDirectory removes it at the end of the test.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            # Path should be extracted and stored
            assert _config_path_args["policy"] == "lerobot/smolvla_base"

            # Cleaned config should not have the policy block at all -- draccus must not
            # try to decode it as PreTrainedConfig; the actual config comes from
            # from_pretrained(path) with the captured overrides applied on top.
            with open(cleaned_path) as f:
                cleaned = yaml.safe_load(f)
            assert "policy" not in cleaned

            # Original dataset should be untouched
            assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"

            # Sibling overrides (excluding type/path) captured for from_pretrained.
            overrides = get_yaml_overrides("policy")
            assert any("push_to_hub=false" in o for o in overrides)
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
def test_extract_path_fields_from_json():
    """Test that policy.path is extracted from a JSON config and the policy
    block is removed entirely.

    The input file is written into a temporary directory that is deleted when
    the test finishes, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak state into other tests.
    """
    config = {
        "policy": {"type": "act", "path": "some/local/path"},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed; the
            # enclosing TemporaryDirectory removes it at the end of the test.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".json", dir=tmp_dir, delete=False
            ) as f:
                json.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            assert _config_path_args["policy"] == "some/local/path"

            with open(cleaned_path) as f:
                cleaned = json.load(f)
            assert "policy" not in cleaned
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
def test_extract_no_path_returns_original():
    """Test that configs without path fields are returned unchanged.

    The input file is written into a temporary directory that is deleted when
    the test finishes, and ``_config_path_args`` is reset in a ``finally``
    block so a failing assertion cannot leak state into other tests.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"type": "smolvla"},
    }

    _config_path_args.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed; the
            # enclosing TemporaryDirectory removes it at the end of the test.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            result = extract_path_fields_from_config(config_path, ["policy"])

            # No path field present: the very same path must come back and
            # nothing may have been recorded for the policy field.
            assert result == config_path
            assert "policy" not in _config_path_args
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_extract_removes_empty_field():
    """Test that the field dict is removed entirely if path was the only key.

    The input file is written into a temporary directory that is deleted when
    the test finishes, and ``_config_path_args`` is reset in a ``finally``
    block so a failing assertion cannot leak state into other tests.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"path": "lerobot/smolvla_base"},
    }

    _config_path_args.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed; the
            # enclosing TemporaryDirectory removes it at the end of the test.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            assert _config_path_args["policy"] == "lerobot/smolvla_base"

            with open(cleaned_path) as f:
                cleaned = yaml.safe_load(f)
            assert "policy" not in cleaned
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_get_path_arg_fallback():
    """Test that get_path_arg falls back to _config_path_args when CLI has no path.

    The injected ``_config_path_args`` entry is removed in a ``finally`` block
    so it cannot survive a failing assertion and influence later tests.
    """
    _config_path_args.clear()
    try:
        _config_path_args["policy"] = "lerobot/smolvla_base"

        # No CLI args with --policy.path
        result = get_path_arg("policy", args=[])
        assert result == "lerobot/smolvla_base"
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_get_path_arg_cli_takes_precedence():
    """Test that CLI --policy.path takes precedence over YAML config path.

    The injected ``_config_path_args`` entry is removed in a ``finally`` block
    so it cannot survive a failing assertion and influence later tests.
    """
    _config_path_args.clear()
    try:
        _config_path_args["policy"] = "yaml/path"

        result = get_path_arg("policy", args=["--policy.path=cli/path"])
        assert result == "cli/path"
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_yaml_overrides_captured():
    """Test that non-path policy fields are captured as CLI-style overrides.

    The input file is written into a temporary directory that is deleted when
    the test finishes, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak state into other tests.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base", "lr": 1e-4, "batch_size": 32},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed; the
            # enclosing TemporaryDirectory removes it at the end of the test.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            extract_path_fields_from_config(config_path, ["policy"])

        overrides = get_yaml_overrides("policy")
        # Only the presence of an lr override is checked, not its exact float
        # formatting (an exact "--lr=0.0001" match would satisfy this as well).
        assert any("lr=" in o for o in overrides)
        assert any("batch_size=32" in o for o in overrides)
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
def test_yaml_overrides_excludes_type_and_path():
    """Test that type and path fields are not included in YAML overrides.

    The input file is written into a temporary directory that is deleted when
    the test finishes, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak state into other tests.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base", "type": "smolvla", "lr": 5e-5},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed; the
            # enclosing TemporaryDirectory removes it at the end of the test.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            extract_path_fields_from_config(config_path, ["policy"])

        overrides = get_yaml_overrides("policy")
        assert not any("path=" in o for o in overrides)
        assert not any("type=" in o for o in overrides)
        assert any("lr=" in o for o in overrides)
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
def test_get_yaml_overrides_empty_when_path_only():
    """Test that get_yaml_overrides returns [] when policy had only a path field.

    The input file is written into a temporary directory that is deleted when
    the test finishes, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak state into other tests.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base"},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed; the
            # enclosing TemporaryDirectory removes it at the end of the test.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            extract_path_fields_from_config(config_path, ["policy"])

        assert get_yaml_overrides("policy") == []
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
def test_flatten_bool_values():
    """Check that booleans are emitted as lowercase ``true``/``false`` strings
    (the form draccus expects) while other scalars pass through as written."""
    flattened = _flatten_to_cli_args(
        {"push_to_hub": True, "use_rabc": False, "lr": 0.001, "name": "test"}
    )
    expected_args = (
        "--push_to_hub=true",
        "--use_rabc=false",
        "--lr=0.001",
        "--name=test",
    )
    for expected in expected_args:
        assert expected in flattened
|
|
|
|
|
|
def test_flatten_none_values_skipped():
    """Check that keys whose value is ``None`` are left out of the flattened args."""
    flattened = _flatten_to_cli_args({"lr": 0.001, "path_override": None, "name": "test"})

    # Keys with real values must still be present ...
    for kept in ("lr=", "name="):
        assert any(kept in arg for arg in flattened)
    # ... while the None-valued key must not appear anywhere.
    assert all("path_override" not in arg for arg in flattened)
|
|
|
|
|
|
def test_flatten_nested_with_bools():
    """Check that nested dicts flatten to dotted keys and that booleans inside
    them are serialized in lowercase."""
    nested = {"optimizer": {"use_warmup": True, "lr": 0.01}}
    flattened = _flatten_to_cli_args(nested)
    for expected in ("--optimizer.use_warmup=true", "--optimizer.lr=0.01"):
        assert expected in flattened
|
|
|
|
|
|
def test_extract_removes_field_with_siblings_and_no_type():
    """Regression: when policy.path has siblings but no type:, the entire policy
    block must still be removed from the cleaned config. Otherwise draccus tries
    to decode the leftover dict as PreTrainedConfig and crashes on the missing
    type discriminator.

    The input file is written into a temporary directory that is deleted when
    the test finishes, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak state into other tests.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {
            "path": "lerobot/smolvla_base",
            "n_action_steps": 10,
            "dtype": "bfloat16",
        },
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed; the
            # enclosing TemporaryDirectory removes it at the end of the test.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            with open(cleaned_path) as f:
                # safe_load returns None for an empty document; normalise to {}
                cleaned = yaml.safe_load(f) or {}

        assert "policy" not in cleaned, "policy block should be fully removed when path is present"
        assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
        assert _config_path_args["policy"] == "lerobot/smolvla_base"
        overrides = get_yaml_overrides("policy")
        assert any("n_action_steps=10" in o for o in overrides)
        assert any("dtype=bfloat16" in o for o in overrides)
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
@dataclass
class _DummyNested:
    """Minimal nested config used by the wrap() regression test."""

    # Single integer field; defaults to 0 when nothing sets it.
    foo: int = field(default=0)
|
|
|
|
|
|
@dataclass
class _DummyConfig:
    """Minimal top-level config used by the wrap() regression test.

    It declares ``nested`` as its only path field via ``__get_path_fields__``.
    """

    # Sub-config; a fresh _DummyNested is built for every instance.
    nested: _DummyNested = field(default_factory=_DummyNested)
    # Plain string field with a fixed default.
    other: str = field(default="default")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """Return the names of the fields that may carry a ``path`` entry."""
        return ["nested"]
|
|
|
|
|
|
def test_wrap_uses_cleaned_config_for_draccus_parse():
    """Regression: wrap() updates config_path_cli to point at the cleaned temp
    file but must propagate that to the draccus.parse fallback branch. Without
    the fix, cli_args still contains --config_path=<original> and draccus reads
    the original YAML with `path:` still in it, crashing on the unknown field.

    The input file is written into a temporary directory that is deleted when
    the test finishes, and the module-level parser state is reset in a
    ``finally`` block so a failing assertion cannot leak state into other tests.
    """
    config = {
        "nested": {"path": "some/checkpoint", "foo": 42},
        "other": "set-via-yaml",
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        captured: dict = {}

        @parser.wrap()
        def main(cfg: _DummyConfig) -> _DummyConfig:
            captured["cfg"] = cfg
            return cfg

        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk after it is closed; the
            # enclosing TemporaryDirectory removes it at the end of the test.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            # The wrapped function is driven through sys.argv, as on the command line.
            with patch.object(sys, "argv", ["prog", f"--config_path={config_path}"]):
                main()

        assert captured["cfg"].other == "set-via-yaml"
        assert _config_path_args["nested"] == "some/checkpoint"
        # Cleaned config dropped `nested:` entirely; defaults stand for this wrapper
        # class (a real PreTrainedConfig would now load the checkpoint and apply
        # the captured yaml_overrides via from_pretrained()).
        assert captured["cfg"].nested.foo == 0
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|