Compare commits

..

1 Commits

Author SHA1 Message Date
Khalil Meftah 0d1d5e0a86 feat(hub): add pretrained_revision to pin Hub model versions
- Add pretrained_revision field to PreTrainedConfig (policies) and
RewardModelConfig (reward models), and thread it through make_policy(),
make_pre_post_processors(), and make_reward_model() so that weights and
processor configs can be loaded from a specific Hub commit, branch, or
tag. Defaults to None (latest version, preserving current behavior).
Dataset and env hub loading already supported revision pinning.
2026-06-15 11:58:57 +02:00
7 changed files with 20 additions and 18 deletions
+9 -17
View File
@@ -180,32 +180,24 @@ class WandBLogger:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
batch_data = {}
for k, v in d.items():
# Skip the custom step key here, it's added to the batch below.
if custom_step_key is not None and k == custom_step_key:
continue
if isinstance(v, list):
for i, elem in enumerate(v):
if isinstance(elem, (int | float)):
batch_data[f"{mode}/{k}_{i}"] = elem
continue
if not isinstance(v, (int | float | str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
)
continue
batch_data[f"{mode}/{k}"] = v
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if batch_data:
if custom_step_key is not None:
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
self._wandb.log(batch_data)
else:
self._wandb.log(data=batch_data, step=step)
value_custom_step = d[custom_step_key]
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
self._wandb.log(data)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
+2
View File
@@ -79,6 +79,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: Path | None = None
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained model version.
pretrained_revision: str | None = None
def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device):
+2
View File
@@ -56,6 +56,8 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
device: str | None = None
pretrained_path: str | None = None
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained reward model version.
pretrained_revision: str | None = None
push_to_hub: bool = False
repo_id: str | None = None
+1 -1
View File
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
+4
View File
@@ -252,6 +252,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
def make_pre_post_processors(
policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None,
pretrained_revision: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -309,6 +310,7 @@ def make_pre_post_processors(
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=batch_to_transition,
to_output=transition_to_batch,
revision=pretrained_revision,
)
postprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
@@ -318,6 +320,7 @@ def make_pre_post_processors(
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
revision=pretrained_revision,
)
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
@@ -557,6 +560,7 @@ def make_policy(
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
kwargs["revision"] = cfg.pretrained_revision
policy = policy_cls.from_pretrained(**kwargs)
elif cfg.pretrained_path and cfg.use_peft:
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
+1
View File
@@ -124,6 +124,7 @@ def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel
if cfg.pretrained_path:
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
kwargs["revision"] = cfg.pretrained_revision
reward_model = reward_cls.from_pretrained(**kwargs)
else:
reward_model = reward_cls(**kwargs)
+1
View File
@@ -345,6 +345,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
pretrained_revision=getattr(cfg.policy, "pretrained_revision", None),
**processor_kwargs,
)