diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 235352cd8..2ab170a1f 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -31,7 +31,7 @@ from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle from lerobot.envs.factory import make_env from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy +from lerobot.policies.factory import make_policy, make_processor from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters from lerobot.scripts.eval import eval_policy @@ -140,6 +140,7 @@ def train(cfg: TrainPipelineConfig): cfg=cfg.policy, ds_meta=dataset.meta, ) + preprocessor, _ = make_processor(cfg.policy, cfg.policy.pretrained_path) logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) @@ -203,6 +204,7 @@ def train(cfg: TrainPipelineConfig): for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) + batch = preprocessor(batch) train_tracker.dataloading_s = time.perf_counter() - start_time for key in batch: @@ -284,6 +286,7 @@ def train(cfg: TrainPipelineConfig): if cfg.policy.push_to_hub: policy.push_model_to_hub(cfg) + preprocessor.push_to_hub(cfg.policy.repo_id) def main():