inject dataset stats for pretrained models

This commit is contained in:
Pepijn
2025-09-25 10:05:00 +02:00
parent 9b3669e87e
commit 2740420d87
+12 -2
View File
@@ -180,15 +180,25 @@ def train(cfg: TrainPipelineConfig):
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
if cfg.policy.pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {"device_processor": {"device": device.type}}
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
"normalizer_processor": {"stats": dataset.meta.stats},
}
postprocessor_kwargs["postprocessor_overrides"] = {
"unnormalizer_processor": {"stats": dataset.meta.stats}
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
logging.info("Creating optimizer and scheduler")