mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 12:47:18 +00:00
refactor(configs): untangle config_path/resume resolution in validate()
Split the re-parse HACK block in TrainPipelineConfig.validate() into focused helpers (_resolve_pretrained_from_cli, _resolve_resume_checkpoint) that handle the policy path, reward-model path, and resume config_path as separate, readable units. Behavior-preserving.
This commit is contained in:
@@ -139,10 +139,17 @@ class TrainPipelineConfig(HubMixin):
|
||||
return self.reward_model # type: ignore[return-value]
|
||||
return self.policy # type: ignore[return-value]
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
def _resolve_pretrained_from_cli(self) -> None:
|
||||
"""Resolve the pretrained source passed on the CLI into a loaded config.
|
||||
|
||||
The pretrained paths (`--policy.path`, `--reward_model.path`) and
|
||||
`--config_path` are only recoverable by re-reading the CLI args: draccus
|
||||
has already consumed them by the time `validate()` runs, so they are not
|
||||
reflected on `self`. At most one source applies, in priority order:
|
||||
reward-model path, policy path, then resume.
|
||||
"""
|
||||
reward_model_path = parser.get_path_arg("reward_model")
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
|
||||
if reward_model_path:
|
||||
cli_overrides = parser.get_cli_overrides("reward_model")
|
||||
@@ -151,31 +158,35 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||
elif policy_path:
|
||||
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||
)
|
||||
overrides = parser.get_yaml_overrides("policy") + (parser.get_cli_overrides("policy") or [])
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=overrides)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
self._resolve_resume_checkpoint()
|
||||
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
def _resolve_resume_checkpoint(self) -> None:
|
||||
"""Point the trainable config at the checkpoint named by `--config_path`."""
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.pretrained_path = str(policy_dir)
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.pretrained_path = str(policy_dir)
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
def validate(self) -> None:
|
||||
self._resolve_pretrained_from_cli()
|
||||
|
||||
if self.policy is None and self.reward_model is None:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user