From 2740420d87caf09cf1ac35c83321c582ff116dd0 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 25 Sep 2025 10:05:00 +0200 Subject: [PATCH] inject dataset stats for pretrained models --- src/lerobot/scripts/train.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 8d0c49bce..9e18286f1 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -180,15 +180,25 @@ def train(cfg: TrainPipelineConfig): # Create processors - only provide dataset_stats if not resuming from saved processors processor_kwargs = {} + postprocessor_kwargs = {} if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: # Only provide dataset_stats when not resuming from saved processor state processor_kwargs["dataset_stats"] = dataset.meta.stats if cfg.policy.pretrained_path is not None: - processor_kwargs["preprocessor_overrides"] = {"device_processor": {"device": device.type}} + processor_kwargs["preprocessor_overrides"] = { + "device_processor": {"device": device.type}, + "normalizer_processor": {"stats": dataset.meta.stats}, + } + postprocessor_kwargs["postprocessor_overrides"] = { + "unnormalizer_processor": {"stats": dataset.meta.stats} + } preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + **processor_kwargs, + **postprocessor_kwargs, ) logging.info("Creating optimizer and scheduler")