From f14ac5d4864b5d0fa7fbebab2a41c79162afa53d Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Thu, 10 Jul 2025 11:11:57 +0200 Subject: [PATCH] feat(train): Integrate preprocessor into training pipeline --- src/lerobot/scripts/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 235352cd8..2ab170a1f 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -31,7 +31,7 @@ from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle from lerobot.envs.factory import make_env from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy +from lerobot.policies.factory import make_policy, make_processor from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters from lerobot.scripts.eval import eval_policy @@ -140,6 +140,7 @@ def train(cfg: TrainPipelineConfig): cfg=cfg.policy, ds_meta=dataset.meta, ) + preprocessor, _ = make_processor(cfg.policy, cfg.policy.pretrained_path) logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) @@ -203,6 +204,7 @@ def train(cfg: TrainPipelineConfig): for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) + batch = preprocessor(batch) train_tracker.dataloading_s = time.perf_counter() - start_time for key in batch: @@ -284,6 +286,7 @@ def train(cfg: TrainPipelineConfig): if cfg.policy.push_to_hub: policy.push_model_to_hub(cfg) + preprocessor.push_to_hub(cfg.policy.repo_id) def main():