mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
inject dataset stats for pretrained models
This commit is contained in:
@@ -180,15 +180,25 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||||
processor_kwargs = {}
|
processor_kwargs = {}
|
||||||
|
postprocessor_kwargs = {}
|
||||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
||||||
# Only provide dataset_stats when not resuming from saved processor state
|
# Only provide dataset_stats when not resuming from saved processor state
|
||||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||||
|
|
||||||
if cfg.policy.pretrained_path is not None:
|
if cfg.policy.pretrained_path is not None:
|
||||||
processor_kwargs["preprocessor_overrides"] = {"device_processor": {"device": device.type}}
|
processor_kwargs["preprocessor_overrides"] = {
|
||||||
|
"device_processor": {"device": device.type},
|
||||||
|
"normalizer_processor": {"stats": dataset.meta.stats},
|
||||||
|
}
|
||||||
|
postprocessor_kwargs["postprocessor_overrides"] = {
|
||||||
|
"unnormalizer_processor": {"stats": dataset.meta.stats}
|
||||||
|
}
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
|
policy_cfg=cfg.policy,
|
||||||
|
pretrained_path=cfg.policy.pretrained_path,
|
||||||
|
**processor_kwargs,
|
||||||
|
**postprocessor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Creating optimizer and scheduler")
|
logging.info("Creating optimizer and scheduler")
|
||||||
|
|||||||
Reference in New Issue
Block a user