diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py
index c89d25bca..f7fe267b1 100644
--- a/src/lerobot/configs/train.py
+++ b/src/lerobot/configs/train.py
@@ -139,10 +139,17 @@ class TrainPipelineConfig(HubMixin):
             return self.reward_model  # type: ignore[return-value]
         return self.policy  # type: ignore[return-value]
 
-    def validate(self) -> None:
-        # HACK: We parse again the cli args here to get the pretrained paths if there was some.
-        policy_path = parser.get_path_arg("policy")
+    def _resolve_pretrained_from_cli(self) -> None:
+        """Resolve the pretrained source passed on the CLI into a loaded config.
+
+        The pretrained paths (`--policy.path`, `--reward_model.path`) and
+        `--config_path` are only recoverable by re-reading the CLI args: draccus
+        has already consumed them by the time `validate()` runs, so they are not
+        reflected on `self`. Exactly one source applies, in priority order:
+        reward-model path, policy path, then resume.
+        """
         reward_model_path = parser.get_path_arg("reward_model")
+        policy_path = parser.get_path_arg("policy")
 
         if reward_model_path:
             cli_overrides = parser.get_cli_overrides("reward_model")
@@ -151,31 +158,35 @@ class TrainPipelineConfig(HubMixin):
             )
             self.reward_model.pretrained_path = str(Path(reward_model_path))
         elif policy_path:
-            yaml_overrides = parser.get_yaml_overrides("policy")
-            cli_overrides = parser.get_cli_overrides("policy") or []
-            self.policy = PreTrainedConfig.from_pretrained(
-                policy_path, cli_overrides=yaml_overrides + cli_overrides
-            )
+            overrides = parser.get_yaml_overrides("policy") + (parser.get_cli_overrides("policy") or [])
+            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=overrides)
             self.policy.pretrained_path = Path(policy_path)
         elif self.resume:
-            config_path = parser.parse_arg("config_path")
-            if not config_path:
-                raise ValueError(
-                    f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
-                )
+            self._resolve_resume_checkpoint()
 
-            if not Path(config_path).resolve().exists():
-                raise NotADirectoryError(
-                    f"{config_path=} is expected to be a local path. "
-                    "Resuming from the hub is not supported for now."
-                )
+    def _resolve_resume_checkpoint(self) -> None:
+        """Point the trainable config at the checkpoint named by `--config_path`."""
+        config_path = parser.parse_arg("config_path")
+        if not config_path:
+            raise ValueError(
+                f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
+            )
 
-            policy_dir = Path(config_path).parent
-            if self.policy is not None:
-                self.policy.pretrained_path = policy_dir
-            if self.reward_model is not None:
-                self.reward_model.pretrained_path = str(policy_dir)
-            self.checkpoint_path = policy_dir.parent
+        if not Path(config_path).resolve().exists():
+            raise NotADirectoryError(
+                f"{config_path=} is expected to be a local path. "
+                "Resuming from the hub is not supported for now."
+            )
+
+        policy_dir = Path(config_path).parent
+        if self.policy is not None:
+            self.policy.pretrained_path = policy_dir
+        if self.reward_model is not None:
+            self.reward_model.pretrained_path = str(policy_dir)
+        self.checkpoint_path = policy_dir.parent
+
+    def validate(self) -> None:
+        self._resolve_pretrained_from_cli()
 
         if self.policy is None and self.reward_model is None:
             raise ValueError(