refactor(configs): untangle config_path/resume resolution in validate()

Split the re-parse HACK block in TrainPipelineConfig.validate() into focused
helpers (_resolve_pretrained_from_cli, _resolve_resume_checkpoint) that handle
the policy path, reward-model path, and resume config_path as separate,
readable units. Behavior-preserving.
This commit is contained in:
Nicolas Rabault
2026-06-24 10:15:30 +02:00
parent 6256e69c29
commit 955b172585
+35 -24
View File
@@ -139,10 +139,17 @@ class TrainPipelineConfig(HubMixin):
return self.reward_model # type: ignore[return-value]
return self.policy # type: ignore[return-value]
def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
def _resolve_pretrained_from_cli(self) -> None:
"""Resolve the pretrained source passed on the CLI into a loaded config.
The pretrained paths (`--policy.path`, `--reward_model.path`) and
`--config_path` are only recoverable by re-reading the CLI args: draccus
has already consumed them by the time `validate()` runs, so they are not
reflected on `self`. Exactly one source applies, in priority order:
reward-model path, policy path, then resume.
"""
reward_model_path = parser.get_path_arg("reward_model")
policy_path = parser.get_path_arg("policy")
if reward_model_path:
cli_overrides = parser.get_cli_overrides("reward_model")
@@ -151,31 +158,35 @@ class TrainPipelineConfig(HubMixin):
)
self.reward_model.pretrained_path = str(Path(reward_model_path))
elif policy_path:
yaml_overrides = parser.get_yaml_overrides("policy")
cli_overrides = parser.get_cli_overrides("policy") or []
self.policy = PreTrainedConfig.from_pretrained(
policy_path, cli_overrides=yaml_overrides + cli_overrides
)
overrides = parser.get_yaml_overrides("policy") + (parser.get_cli_overrides("policy") or [])
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=overrides)
self.policy.pretrained_path = Path(policy_path)
elif self.resume:
config_path = parser.parse_arg("config_path")
if not config_path:
raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
self._resolve_resume_checkpoint()
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now."
)
def _resolve_resume_checkpoint(self) -> None:
"""Point the trainable config at the checkpoint named by `--config_path`."""
config_path = parser.parse_arg("config_path")
if not config_path:
raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
policy_dir = Path(config_path).parent
if self.policy is not None:
self.policy.pretrained_path = policy_dir
if self.reward_model is not None:
self.reward_model.pretrained_path = str(policy_dir)
self.checkpoint_path = policy_dir.parent
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now."
)
policy_dir = Path(config_path).parent
if self.policy is not None:
self.policy.pretrained_path = policy_dir
if self.reward_model is not None:
self.reward_model.pretrained_path = str(policy_dir)
self.checkpoint_path = policy_dir.parent
def validate(self) -> None:
self._resolve_pretrained_from_cli()
if self.policy is None and self.reward_model is None:
raise ValueError(