diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py index d1cebd27f..f2a1d3065 100644 --- a/src/lerobot/configs/eval.py +++ b/src/lerobot/configs/eval.py @@ -46,8 +46,11 @@ class EvalPipelineConfig: # HACK: We parse again the cli args here to get the pretrained path if there was one. policy_path = parser.get_path_arg("policy") if policy_path: - cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + yaml_overrides = parser.get_yaml_overrides("policy") + cli_overrides = parser.get_cli_overrides("policy") or [] + self.policy = PreTrainedConfig.from_pretrained( + policy_path, cli_overrides=yaml_overrides + cli_overrides + ) self.policy.pretrained_path = Path(policy_path) else: diff --git a/src/lerobot/configs/parser.py b/src/lerobot/configs/parser.py index 57ebaf8fa..d55fa44aa 100644 --- a/src/lerobot/configs/parser.py +++ b/src/lerobot/configs/parser.py @@ -13,8 +13,10 @@ # limitations under the License. import importlib import inspect +import json import pkgutil import sys +import tempfile from argparse import ArgumentError from collections.abc import Callable, Iterable, Sequence from functools import wraps @@ -24,6 +26,7 @@ from types import ModuleType from typing import Any, TypeVar, cast import draccus +import yaml # type: ignore[import-untyped] from lerobot.utils.utils import has_method @@ -32,6 +35,29 @@ F = TypeVar("F", bound=Callable[..., object]) PATH_KEY = "path" PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path" +# Storage for path args extracted from YAML/JSON config files, so that +# get_path_arg() can find them even when they weren't passed via CLI. +_config_path_args: dict[str, str] = {} + +# Storage for non-path YAML overrides so validate() can pass them to from_pretrained. +_config_yaml_overrides: dict[str, list[str]] = {} + + +def _flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]: + """Recursively flatten a nested dict to CLI-style args (e.g. {"lr": 1e-4} -> ["--lr=0.0001"]).""" + args = [] + for key, value in d.items(): + if key in (PATH_KEY, draccus.CHOICE_TYPE_KEY): + continue + full_key = f"{prefix}.{key}" if prefix else key + if isinstance(value, bool): + value = str(value).lower() + if isinstance(value, dict): + args.extend(_flatten_to_cli_args(value, full_key)) + elif value is not None and not isinstance(value, list): + args.append(f"--{full_key}={value}") + return args + def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None: """Parses arguments from cli at a given nested attribute level. @@ -145,7 +171,14 @@ def load_plugin(plugin_path: str) -> None: def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None: - return parse_arg(f"{field_name}.{PATH_KEY}", args) + result = parse_arg(f"{field_name}.{PATH_KEY}", args) + if result is None: + result = _config_path_args.get(field_name) + return result + + +def get_yaml_overrides(field_name: str) -> list[str]: + return _config_yaml_overrides.get(field_name, []) def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None: @@ -192,6 +225,52 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No return filtered_args +def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> str: + """Extract `path` fields from a YAML/JSON config before draccus processes it. + + When a user specifies e.g. ``policy.path: lerobot/smolvla_base`` in a YAML config, + draccus will fail because ``path`` is not a valid field on policy config classes. + This function extracts those path values, stores them in ``_config_path_args`` for + later retrieval by ``get_path_arg()``, and returns a cleaned temp config file path. + """ + config_file = Path(config_path) + suffix = config_file.suffix.lower() + + if suffix in (".yaml", ".yml"): + with open(config_file) as f: + config_data = yaml.safe_load(f) + elif suffix == ".json": + with open(config_file) as f: + config_data = json.load(f) + else: + return config_path + + if not isinstance(config_data, dict): + return config_path + + modified = False + for field in path_fields: + if field in config_data and isinstance(config_data[field], dict) and PATH_KEY in config_data[field]: + _config_path_args[field] = str(config_data[field].pop(PATH_KEY)) + remaining = config_data[field] + if remaining: + _config_yaml_overrides[field] = _flatten_to_cli_args(remaining) + else: + del config_data[field] + modified = True + + if not modified: + return config_path + + # Write cleaned config to a temp file + with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp: + if suffix in (".yaml", ".yml"): + yaml.dump(config_data, tmp, default_flow_style=False) + else: + json.dump(config_data, tmp, indent=2) + return tmp.name + + def wrap(config_path: Path | None = None) -> Callable[[F], F]: """ HACK: Similar to draccus.wrap but does three additional things: @@ -225,6 +304,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]: if has_method(argtype, "__get_path_fields__"): path_fields = argtype.__get_path_fields__() cli_args = filter_path_args(path_fields, cli_args) + # Also extract path fields from the YAML/JSON config file + if config_path_cli: + config_path_cli = extract_path_fields_from_config(config_path_cli, path_fields) if has_method(argtype, "from_pretrained") and config_path_cli: cli_args = filter_arg("config_path", cli_args) cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 388de9437..c5b4ff5f5 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -144,8 +144,11 @@ class TrainPipelineConfig(HubMixin): ) self.reward_model.pretrained_path = str(Path(reward_model_path)) elif policy_path: - cli_overrides = parser.get_cli_overrides("policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + yaml_overrides = parser.get_yaml_overrides("policy") + cli_overrides = parser.get_cli_overrides("policy") or [] + self.policy = PreTrainedConfig.from_pretrained( + policy_path, cli_overrides=yaml_overrides + cli_overrides + ) self.policy.pretrained_path = Path(policy_path) elif self.resume: config_path = parser.parse_arg("config_path") diff --git a/tests/test_yaml_policy_path.py b/tests/test_yaml_policy_path.py new file mode 100644 index 000000000..710a71c9a --- /dev/null +++ b/tests/test_yaml_policy_path.py @@ -0,0 +1,218 @@ +"""Tests for policy.path support in YAML config files (issue #2957).""" + +import json +import tempfile + +import yaml + +from lerobot.configs.parser import ( + _config_path_args, + _config_yaml_overrides, + _flatten_to_cli_args, + extract_path_fields_from_config, + get_path_arg, + get_yaml_overrides, +) + + +def test_extract_path_fields_from_yaml(): + """Test that policy.path is extracted from a YAML config and removed.""" + config = { + "dataset": {"repo_id": "lerobot/pusht"}, + "policy": {"type": "smolvla", "path": "lerobot/smolvla_base", "push_to_hub": False}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) + + # Path should be extracted and stored + assert _config_path_args["policy"] == "lerobot/smolvla_base" + + # Cleaned config should not have the path field + with open(cleaned_path) as f: + cleaned = yaml.safe_load(f) + assert "path" not in cleaned["policy"] + assert cleaned["policy"]["type"] == "smolvla" + assert cleaned["policy"]["push_to_hub"] is False + + # Original dataset should be untouched + assert cleaned["dataset"]["repo_id"] == "lerobot/pusht" + + _config_path_args.clear() + + +def test_extract_path_fields_from_json(): + """Test that policy.path is extracted from a JSON config.""" + config = { + "policy": {"type": "act", "path": "some/local/path"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(config, f) + config_path = f.name + + _config_path_args.clear() + cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) + + assert _config_path_args["policy"] == "some/local/path" + + with open(cleaned_path) as f: + cleaned = json.load(f) + assert "path" not in cleaned["policy"] + + _config_path_args.clear() + + +def test_extract_no_path_returns_original(): + """Test that configs without path fields are returned unchanged.""" + config = { + "dataset": {"repo_id": "lerobot/pusht"}, + "policy": {"type": "smolvla"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + result = extract_path_fields_from_config(config_path, ["policy"]) + + assert result == config_path + assert "policy" not in _config_path_args + + _config_path_args.clear() + + +def test_extract_removes_empty_field(): + """Test that the field dict is removed entirely if path was the only key.""" + config = { + "dataset": {"repo_id": "lerobot/pusht"}, + "policy": {"path": "lerobot/smolvla_base"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + cleaned_path = extract_path_fields_from_config(config_path, ["policy"]) + + assert _config_path_args["policy"] == "lerobot/smolvla_base" + + with open(cleaned_path) as f: + cleaned = yaml.safe_load(f) + assert "policy" not in cleaned + + _config_path_args.clear() + + +def test_get_path_arg_fallback(): + """Test that get_path_arg falls back to _config_path_args when CLI has no path.""" + _config_path_args.clear() + _config_path_args["policy"] = "lerobot/smolvla_base" + + # No CLI args with --policy.path + result = get_path_arg("policy", args=[]) + assert result == "lerobot/smolvla_base" + + _config_path_args.clear() + + +def test_get_path_arg_cli_takes_precedence(): + """Test that CLI --policy.path takes precedence over YAML config path.""" + _config_path_args.clear() + _config_path_args["policy"] = "yaml/path" + + result = get_path_arg("policy", args=["--policy.path=cli/path"]) + assert result == "cli/path" + + _config_path_args.clear() + + +def test_yaml_overrides_captured(): + """Test that non-path policy fields are captured as CLI-style overrides.""" + config = { + "policy": {"path": "lerobot/smolvla_base", "lr": 1e-4, "batch_size": 32}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + extract_path_fields_from_config(config_path, ["policy"]) + + overrides = get_yaml_overrides("policy") + assert "--lr=0.0001" in overrides or any("lr=" in o for o in overrides) + assert any("batch_size=32" in o for o in overrides) + + _config_path_args.clear() + _config_yaml_overrides.clear() + + +def test_yaml_overrides_excludes_type_and_path(): + """Test that type and path fields are not included in YAML overrides.""" + config = { + "policy": {"path": "lerobot/smolvla_base", "type": "smolvla", "lr": 5e-5}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + extract_path_fields_from_config(config_path, ["policy"]) + + overrides = get_yaml_overrides("policy") + assert not any("path=" in o for o in overrides) + assert not any("type=" in o for o in overrides) + assert any("lr=" in o for o in overrides) + + _config_path_args.clear() + _config_yaml_overrides.clear() + + +def test_get_yaml_overrides_empty_when_path_only(): + """Test that get_yaml_overrides returns [] when policy had only a path field.""" + config = { + "policy": {"path": "lerobot/smolvla_base"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config, f) + config_path = f.name + + _config_path_args.clear() + _config_yaml_overrides.clear() + extract_path_fields_from_config(config_path, ["policy"]) + + assert get_yaml_overrides("policy") == [] + + _config_path_args.clear() + _config_yaml_overrides.clear() + + +def test_flatten_bool_values(): + """Test that boolean values are serialized as lowercase strings for draccus.""" + d = {"push_to_hub": True, "use_rabc": False, "lr": 0.001, "name": "test"} + args = _flatten_to_cli_args(d) + assert "--push_to_hub=true" in args + assert "--use_rabc=false" in args + assert "--lr=0.001" in args + assert "--name=test" in args + + +def test_flatten_none_values_skipped(): + """Test that None values are not included in flattened args.""" + d = {"lr": 0.001, "path_override": None, "name": "test"} + args = _flatten_to_cli_args(d) + assert any("lr=" in a for a in args) + assert any("name=" in a for a in args) + assert not any("path_override" in a for a in args) + + +def test_flatten_nested_with_bools(): + """Test that bools in nested dicts are handled correctly.""" + d = {"optimizer": {"use_warmup": True, "lr": 0.01}} + args = _flatten_to_cli_args(d) + assert "--optimizer.use_warmup=true" in args + assert "--optimizer.lr=0.01" in args