mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
fe96b28c74
* fix(config): support policy.path in YAML config files policy.path was only handled via CLI args (filtered from sys.argv before draccus, then retrieved in validate()). When specified in YAML, draccus would crash because 'path' is not a valid field on PreTrainedConfig. Extract path fields from the YAML/JSON config before draccus processes it, store them in a module-level dict, and fall back to it in get_path_arg() when the CLI doesn't have the path. Fixes #2957 * fix(parser): preserve YAML policy overrides when loading from pretrained When policy.path is set in YAML, validate() was calling from_pretrained with only CLI overrides, discarding any YAML policy fields (e.g. lr, batch_size) that draccus had already parsed. Fix by capturing the remaining YAML fields as CLI-style args in _config_yaml_overrides and merging them into the overrides passed to from_pretrained in train.py, eval.py, and lerobot_record.py (CLI args still take precedence). Also fix the NamedTemporaryFile SIM115 ruff warning and add types-PyYAML to the mypy pre-commit hook. * fix(parser): serialize bool/None values correctly in YAML policy overrides Bool values from YAML configs (e.g. push_to_hub: true) were passed as Python "True"/"False" strings instead of lowercase "true"/"false" that draccus expects. Also skip None values to avoid passing "None" strings. * revert: remove types-PyYAML from .pre-commit-config.yaml * chore: fix quality check caused by untyped YAML import Co-authored-by: masato-ka <jp6uzv@gmail.com> Signed-off-by: Khalil Meftah <khalil.meftah@huggingface.co> --------- Signed-off-by: Khalil Meftah <khalil.meftah@huggingface.co> Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co> Co-authored-by: masato-ka <jp6uzv@gmail.com>
219 lines
6.9 KiB
Python
219 lines
6.9 KiB
Python
"""Tests for policy.path support in YAML config files (issue #2957)."""
|
|
|
|
import json
|
|
import tempfile
|
|
|
|
import yaml
|
|
|
|
from lerobot.configs.parser import (
|
|
_config_path_args,
|
|
_config_yaml_overrides,
|
|
_flatten_to_cli_args,
|
|
extract_path_fields_from_config,
|
|
get_path_arg,
|
|
get_yaml_overrides,
|
|
)
|
|
|
|
|
|
def test_extract_path_fields_from_yaml():
    """Test that policy.path is extracted from a YAML config and removed.

    The input config is written inside a temporary directory so it is deleted
    when the test finishes, and the module-level ``_config_path_args`` dict is
    reset in a ``finally`` block so a failing assertion cannot leak state into
    other tests.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False},
    }

    _config_path_args.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk once the handle is closed;
            # the enclosing TemporaryDirectory takes care of removing it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            # Path should be extracted and stored
            assert _config_path_args["policy"] == "lerobot/smolvla_base"

            # Read the cleaned config while the temporary directory still
            # exists, in case the cleaned file was written next to the input.
            with open(cleaned_path) as f:
                cleaned = yaml.safe_load(f)

        # Cleaned config should not have the path field
        assert "path" not in cleaned["policy"]
        assert cleaned["policy"]["type"] == "smolvla"
        assert cleaned["policy"]["push_to_hub"] is False

        # Original dataset should be untouched
        assert cleaned["dataset"]["repo_id"] == "lerobot/pusht"
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_extract_path_fields_from_json():
    """Test that policy.path is extracted from a JSON config.

    The input config is written inside a temporary directory so it is deleted
    when the test finishes, and ``_config_path_args`` is reset in a ``finally``
    block so the module-level state is restored even if an assertion fails.
    """
    config = {
        "policy": {"type": "act", "path": "some/local/path"},
    }

    _config_path_args.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk once the handle is closed;
            # the enclosing TemporaryDirectory takes care of removing it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".json", dir=tmp_dir, delete=False
            ) as f:
                json.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            assert _config_path_args["policy"] == "some/local/path"

            # Read the cleaned config while the temporary directory still
            # exists, in case the cleaned file was written next to the input.
            with open(cleaned_path) as f:
                cleaned = json.load(f)

        assert "path" not in cleaned["policy"]
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_extract_no_path_returns_original():
    """Test that configs without path fields are returned unchanged.

    The input config is written inside a temporary directory so it is deleted
    when the test finishes, and ``_config_path_args`` is reset in a ``finally``
    block so the module-level state is restored even if an assertion fails.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"type": "smolvla"},
    }

    _config_path_args.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk once the handle is closed;
            # the enclosing TemporaryDirectory takes care of removing it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            result = extract_path_fields_from_config(config_path, ["policy"])

        # With nothing to extract, the very same path must be handed back
        assert result == config_path
        assert "policy" not in _config_path_args
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_extract_removes_empty_field():
    """Test that the field dict is removed entirely if path was the only key.

    The input config is written inside a temporary directory so it is deleted
    when the test finishes, and ``_config_path_args`` is reset in a ``finally``
    block so the module-level state is restored even if an assertion fails.
    """
    config = {
        "dataset": {"repo_id": "lerobot/pusht"},
        "policy": {"path": "lerobot/smolvla_base"},
    }

    _config_path_args.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk once the handle is closed;
            # the enclosing TemporaryDirectory takes care of removing it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            cleaned_path = extract_path_fields_from_config(config_path, ["policy"])

            assert _config_path_args["policy"] == "lerobot/smolvla_base"

            # Read the cleaned config while the temporary directory still
            # exists, in case the cleaned file was written next to the input.
            with open(cleaned_path) as f:
                cleaned = yaml.safe_load(f)

        assert "policy" not in cleaned
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_get_path_arg_fallback():
    """Test that get_path_arg falls back to _config_path_args when CLI has no path.

    ``_config_path_args`` is module-level state, so it is reset in a
    ``finally`` block to keep a failing assertion from leaking the seeded
    entry into other tests.
    """
    _config_path_args.clear()
    _config_path_args["policy"] = "lerobot/smolvla_base"
    try:
        # No CLI args with --policy.path
        result = get_path_arg("policy", args=[])
        assert result == "lerobot/smolvla_base"
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_get_path_arg_cli_takes_precedence():
    """Test that CLI --policy.path takes precedence over YAML config path.

    ``_config_path_args`` is module-level state, so it is reset in a
    ``finally`` block to keep a failing assertion from leaking the seeded
    entry into other tests.
    """
    _config_path_args.clear()
    _config_path_args["policy"] = "yaml/path"
    try:
        result = get_path_arg("policy", args=["--policy.path=cli/path"])
        assert result == "cli/path"
    finally:
        _config_path_args.clear()
|
|
|
|
|
|
def test_yaml_overrides_captured():
    """Test that non-path policy fields are captured as CLI-style overrides.

    The input config is written inside a temporary directory so it is deleted
    when the test finishes, and both module-level dicts are reset in a
    ``finally`` block so state is restored even if an assertion fails.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base", "lr": 1e-4, "batch_size": 32},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk once the handle is closed;
            # the enclosing TemporaryDirectory takes care of removing it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            extract_path_fields_from_config(config_path, ["policy"])

        overrides = get_yaml_overrides("policy")
        # Only the presence of an lr override is checked; the exact float
        # formatting is deliberately not pinned here. (The former
        # `"--lr=0.0001" in overrides or ...` operand was implied by this
        # check, so dropping it does not change what the test accepts.)
        assert any("lr=" in o for o in overrides)
        assert any("batch_size=32" in o for o in overrides)
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
def test_yaml_overrides_excludes_type_and_path():
    """Test that type and path fields are not included in YAML overrides.

    The input config is written inside a temporary directory so it is deleted
    when the test finishes, and both module-level dicts are reset in a
    ``finally`` block so state is restored even if an assertion fails.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base", "type": "smolvla", "lr": 5e-5},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk once the handle is closed;
            # the enclosing TemporaryDirectory takes care of removing it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            extract_path_fields_from_config(config_path, ["policy"])

        overrides = get_yaml_overrides("policy")
        assert not any("path=" in o for o in overrides)
        assert not any("type=" in o for o in overrides)
        assert any("lr=" in o for o in overrides)
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
def test_get_yaml_overrides_empty_when_path_only():
    """Test that get_yaml_overrides returns [] when policy had only a path field.

    The input config is written inside a temporary directory so it is deleted
    when the test finishes, and both module-level dicts are reset in a
    ``finally`` block so state is restored even if an assertion fails.
    """
    config = {
        "policy": {"path": "lerobot/smolvla_base"},
    }

    _config_path_args.clear()
    _config_yaml_overrides.clear()
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # delete=False keeps the file on disk once the handle is closed;
            # the enclosing TemporaryDirectory takes care of removing it.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".yaml", dir=tmp_dir, delete=False
            ) as f:
                yaml.dump(config, f)
                config_path = f.name

            extract_path_fields_from_config(config_path, ["policy"])

        assert get_yaml_overrides("policy") == []
    finally:
        _config_path_args.clear()
        _config_yaml_overrides.clear()
|
|
|
|
|
|
def test_flatten_bool_values():
    """Check that booleans are emitted as lowercase ``true``/``false`` strings.

    Non-boolean values (a float and a string) are included to confirm they are
    passed through with their plain string form.
    """
    flattened = _flatten_to_cli_args(
        {"push_to_hub": True, "use_rabc": False, "lr": 0.001, "name": "test"}
    )
    expected_args = (
        "--push_to_hub=true",
        "--use_rabc=false",
        "--lr=0.001",
        "--name=test",
    )
    for expected in expected_args:
        assert expected in flattened
|
|
|
|
|
|
def test_flatten_none_values_skipped():
    """Check that keys whose value is ``None`` produce no flattened argument."""
    flattened = _flatten_to_cli_args({"lr": 0.001, "path_override": None, "name": "test"})

    # Keys with real values must each show up in at least one argument
    for present in ("lr=", "name="):
        assert any(present in arg for arg in flattened)

    # The None-valued key must not appear anywhere
    assert all("path_override" not in arg for arg in flattened)
|
|
|
|
|
|
def test_flatten_nested_with_bools():
    """Check that nested dicts flatten to dotted keys with lowercase bools."""
    nested = {"optimizer": {"use_warmup": True, "lr": 0.01}}
    flattened = _flatten_to_cli_args(nested)

    for expected in ("--optimizer.use_warmup=true", "--optimizer.lr=0.01"):
        assert expected in flattened
|