From fe96b28c7457629588dd113df11bb853981587f4 Mon Sep 17 00:00:00 2001 From: Jash Shah Date: Wed, 13 May 2026 00:45:27 -0700 Subject: [PATCH] Fix policy.path not working in YAML config files (#3145) * fix(config): support policy.path in YAML config files policy.path was only handled via CLI args (filtered from sys.argv before draccus, then retrieved in validate()). When specified in YAML, draccus would crash because 'path' is not a valid field on PreTrainedConfig. Extract path fields from the YAML/JSON config before draccus processes it, store them in a module-level dict, and fall back to it in get_path_arg() when the CLI doesn't have the path. Fixes #2957 * fix(parser): preserve YAML policy overrides when loading from pretrained When policy.path is set in YAML, validate() was calling from_pretrained with only CLI overrides, discarding any YAML policy fields (e.g. lr, batch_size) that draccus had already parsed. Fix by capturing the remaining YAML fields as CLI-style args in _config_yaml_overrides and merging them into the overrides passed to from_pretrained in train.py, eval.py, and lerobot_record.py (CLI args still take precedence). Also fix the NamedTemporaryFile SIM115 ruff warning and add types-PyYAML to the mypy pre-commit hook. * fix(parser): serialize bool/None values correctly in YAML policy overrides Bool values from YAML configs (e.g. push_to_hub: true) were passed as Python "True"/"False" strings instead of lowercase "true"/"false" that draccus expects. Also skip None values to avoid passing "None" strings. * revert: remove types-PyYAML from .pre-commit-config.yaml * chore: fix quality check caused by untyped YAML import Co-authored-by: masato-ka Signed-off-by: Khalil Meftah --------- Signed-off-by: Khalil Meftah Co-authored-by: Khalil Meftah Co-authored-by: masato-ka --- src/lerobot/configs/eval.py | 7 +- src/lerobot/configs/parser.py | 84 ++++++++++++- src/lerobot/configs/train.py | 7 +- tests/test_yaml_policy_path.py | 218 +++++++++++++++++++++++++++++++++ 4 files changed, 311 insertions(+), 5 deletions(-) create mode 100644 tests/test_yaml_policy_path.py diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py index d1cebd27f..f2a1d3065 100644 --- a/src/lerobot/configs/eval.py +++ b/src/lerobot/configs/eval.py @@ -46,8 +46,11 @@ class EvalPipelineConfig: # HACK: We parse again the cli args here to get the pretrained path if there was one. policy_path = parser.get_path_arg("policy") if policy_path: - cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + yaml_overrides = parser.get_yaml_overrides("policy") + cli_overrides = parser.get_cli_overrides("policy") or [] + self.policy = PreTrainedConfig.from_pretrained( + policy_path, cli_overrides=yaml_overrides + cli_overrides + ) self.policy.pretrained_path = Path(policy_path) else: diff --git a/src/lerobot/configs/parser.py b/src/lerobot/configs/parser.py index 57ebaf8fa..d55fa44aa 100644 --- a/src/lerobot/configs/parser.py +++ b/src/lerobot/configs/parser.py @@ -13,8 +13,10 @@ # limitations under the License. import importlib import inspect +import json import pkgutil import sys +import tempfile from argparse import ArgumentError from collections.abc import Callable, Iterable, Sequence from functools import wraps @@ -24,6 +26,7 @@ from types import ModuleType from typing import Any, TypeVar, cast import draccus +import yaml # type: ignore[import-untyped] from lerobot.utils.utils import has_method @@ -32,6 +35,29 @@ F = TypeVar("F", bound=Callable[..., object]) PATH_KEY = "path" PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path" +# Storage for path args extracted from YAML/JSON config files, so that +# get_path_arg() can find them even when they weren't passed via CLI. +_config_path_args: dict[str, str] = {} + +# Storage for non-path YAML overrides so validate() can pass them to from_pretrained. +_config_yaml_overrides: dict[str, list[str]] = {} + + +def _flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]: + """Recursively flatten a nested dict to CLI-style args (e.g. {"lr": 1e-4} -> ["--lr=0.0001"]).""" + args = [] + for key, value in d.items(): + if key in (PATH_KEY, draccus.CHOICE_TYPE_KEY): + continue + full_key = f"{prefix}.{key}" if prefix else key + if isinstance(value, bool): + value = str(value).lower() + if isinstance(value, dict): + args.extend(_flatten_to_cli_args(value, full_key)) + elif value is not None and not isinstance(value, list): + args.append(f"--{full_key}={value}") + return args + def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None: """Parses arguments from cli at a given nested attribute level. @@ -145,7 +171,14 @@ def load_plugin(plugin_path: str) -> None: def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None: - return parse_arg(f"{field_name}.{PATH_KEY}", args) + result = parse_arg(f"{field_name}.{PATH_KEY}", args) + if result is None: + result = _config_path_args.get(field_name) + return result + + +def get_yaml_overrides(field_name: str) -> list[str]: + return _config_yaml_overrides.get(field_name, []) def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None: @@ -192,6 +225,52 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No return filtered_args +def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> str: + """Extract `path` fields from a YAML/JSON config before draccus processes it. + + When a user specifies e.g. ``policy.path: lerobot/smolvla_base`` in a YAML config, + draccus will fail because ``path`` is not a valid field on policy config classes. + This function extracts those path values, stores them in ``_config_path_args`` for + later retrieval by ``get_path_arg()``, and returns a cleaned temp config file path. + """ + config_file = Path(config_path) + suffix = config_file.suffix.lower() + + if suffix in (".yaml", ".yml"): + with open(config_file) as f: + config_data = yaml.safe_load(f) + elif suffix == ".json": + with open(config_file) as f: + config_data = json.load(f) + else: + return config_path + + if not isinstance(config_data, dict): + return config_path + + modified = False + for field in path_fields: + if field in config_data and isinstance(config_data[field], dict) and PATH_KEY in config_data[field]: + _config_path_args[field] = str(config_data[field].pop(PATH_KEY)) + remaining = config_data[field] + if remaining: + _config_yaml_overrides[field] = _flatten_to_cli_args(remaining) + else: + del config_data[field] + modified = True + + if not modified: + return config_path + + # Write cleaned config to a temp file + with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp: + if suffix in (".yaml", ".yml"): + yaml.dump(config_data, tmp, default_flow_style=False) + else: + json.dump(config_data, tmp, indent=2) + return tmp.name + + def wrap(config_path: Path | None = None) -> Callable[[F], F]: """ HACK: Similar to draccus.wrap but does three additional things: @@ -225,6 +304,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]: if has_method(argtype, "__get_path_fields__"): path_fields = argtype.__get_path_fields__() cli_args = filter_path_args(path_fields, cli_args) + # Also extract path fields from the YAML/JSON config file + if config_path_cli: + config_path_cli = extract_path_fields_from_config(config_path_cli, path_fields) if has_method(argtype, "from_pretrained") and config_path_cli: cli_args = filter_arg("config_path", cli_args) cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 388de9437..c5b4ff5f5 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -144,8 +144,11 @@ class TrainPipelineConfig(HubMixin): ) self.reward_model.pretrained_path = str(Path(reward_model_path)) elif policy_path: - cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + yaml_overrides = parser.get_yaml_overrides("policy") + cli_overrides = parser.get_cli_overrides("policy") or [] + self.policy = PreTrainedConfig.from_pretrained( + policy_path, cli_overrides=yaml_overrides + cli_overrides + ) self.policy.pretrained_path = Path(policy_path) elif self.resume: config_path = parser.parse_arg("config_path") diff --git a/tests/test_yaml_policy_path.py b/tests/test_yaml_policy_path.py new file mode 100644 index 000000000..710a71c9a --- /dev/null +++ b/tests/test_yaml_policy_path.py @@ -0,0 +1,218 @@ +"""Tests for policy.path support in YAML config files (issue #2957).""" + +import json +import tempfile + +import yaml + +from lerobot.configs.parser import ( + _config_path_args, + _config_yaml_overrides, + _flatten_to_cli_args, + extract_path_fields_from_config, + get_path_arg, + get_yaml_overrides, +) + + +def test_extract_path_fields_from_yaml(): + """Test that policy.path is extracted from a YAML config and removed.""" + config = { + "dataset": {"repo_id": "lerobot/pusht"}, + "policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) + + # Path should be extracted and stored + assert _config_path_args["policy"] == "lerobot/smolvla_base" + + # Cleaned config should not have the path field + with open(cleaned_path) as f: + cleaned = yaml.safe_load(f) + assert "path" not in cleaned["policy"] + assert cleaned["policy"]["type"] == "smolvla" + assert cleaned["policy"]["push_to_hub"] is False + + # Original dataset should be untouched + assert cleaned["dataset"]["repo_id"] == "lerobot/pusht" + + _config_path_args.clear() + + +def test_extract_path_fields_from_json(): + """Test that policy.path is extracted from a JSON config.""" + config = { + "policy": {"type": "act", "path": "some/local/path"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(config, f) + config_path = f.name + + _config_path_args.clear() + cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) + + assert _config_path_args["policy"] == "some/local/path" + + with open(cleaned_path) as f: + cleaned = json.load(f) + assert "path" not in cleaned["policy"] + + _config_path_args.clear() + + +def test_extract_no_path_returns_original(): + """Test that configs without path fields are returned unchanged.""" + config = { + "dataset": {"repo_id": "lerobot/pusht"}, + "policy": {"type": "smolvla"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + result = extract_path_fields_from_config(config_path, ["policy"]) + + assert result == config_path + assert "policy" not in _config_path_args + + _config_path_args.clear() + + +def test_extract_removes_empty_field(): + """Test that the field dict is removed entirely if path was the only key.""" + config = { + "dataset": {"repo_id": "lerobot/pusht"}, + "policy": {"path": "lerobot/smolvla_base"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) + + assert _config_path_args["policy"] == "lerobot/smolvla_base" + + with open(cleaned_path) as f: + cleaned = yaml.safe_load(f) + assert "policy" not in cleaned + + _config_path_args.clear() + + +def test_get_path_arg_fallback(): + """Test that get_path_arg falls back to _config_path_args when CLI has no path.""" + _config_path_args.clear() + _config_path_args["policy"] = "lerobot/smolvla_base" + + # No CLI args with --policy.path + result = get_path_arg("policy", args=[]) + assert result == "lerobot/smolvla_base" + + _config_path_args.clear() + + +def test_get_path_arg_cli_takes_precedence(): + """Test that CLI --policy.path takes precedence over YAML config path.""" + _config_path_args.clear() + _config_path_args["policy"] = "yaml/path" + + result = get_path_arg("policy", args=["--policy.path=cli/path"]) + assert result == "cli/path" + + _config_path_args.clear() + + +def test_yaml_overrides_captured(): + """Test that non-path policy fields are captured as CLI-style overrides.""" + config = { + "policy": {"path": "lerobot/smolvla_base", "lr": 1e-4, "batch_size": 32}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + extract_path_fields_from_config(config_path, ["policy"]) + + overrides = get_yaml_overrides("policy") + assert "--lr=0.0001" in overrides or any("lr=" in o for o in overrides) + assert any("batch_size=32" in o for o in overrides) + + _config_path_args.clear() + _config_yaml_overrides.clear() + + +def test_yaml_overrides_excludes_type_and_path(): + """Test that type and path fields are not included in YAML overrides.""" + config = { + "policy": {"path": "lerobot/smolvla_base", "type": "smolvla", "lr": 5e-5}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + extract_path_fields_from_config(config_path, ["policy"]) + + overrides = get_yaml_overrides("policy") + assert not any("path=" in o for o in overrides) + assert not any("type=" in o for o in overrides) + assert any("lr=" in o for o in overrides) + + _config_path_args.clear() + _config_yaml_overrides.clear() + + +def test_get_yaml_overrides_empty_when_path_only(): + """Test that get_yaml_overrides returns [] when policy had only a path field.""" + config = { + "policy": {"path": "lerobot/smolvla_base"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + extract_path_fields_from_config(config_path, ["policy"]) + + assert get_yaml_overrides("policy") == [] + + _config_path_args.clear() + _config_yaml_overrides.clear() + + +def test_flatten_bool_values(): + """Test that boolean values are serialized as lowercase strings for draccus.""" + d = {"push_to_hub": True, "use_rabc": False, "lr": 0.001, "name": "test"} + args = _flatten_to_cli_args(d) + assert "--push_to_hub=true" in args + assert "--use_rabc=false" in args + assert "--lr=0.001" in args + assert "--name=test" in args + + +def test_flatten_none_values_skipped(): + """Test that None values are not included in flattened args.""" + d = {"lr": 0.001, "path_override": None, "name": "test"} + args = _flatten_to_cli_args(d) + assert any("lr=" in a for a in args) + assert any("name=" in a for a in args) + assert not any("path_override" in a for a in args) + + +def test_flatten_nested_with_bools(): + """Test that bools in nested dicts are handled correctly.""" + d = {"optimizer": {"use_warmup": True, "lr": 0.01}} + args = _flatten_to_cli_args(d) + assert "--optimizer.use_warmup=true" in args + assert "--optimizer.lr=0.01" in args