diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index e6cf40b9b..7e02f879f 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -29,7 +29,7 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle -from lerobot.envs.factory import make_env +from lerobot.envs.factory import make_env, make_env_pre_post_processors from lerobot.envs.utils import close_envs from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_pre_post_processors @@ -259,6 +259,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") if cfg.env is not None: logging.info(f"{cfg.env.task=}") + logging.info("Creating environment processors") + env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env) logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") logging.info(f"{dataset.num_episodes=}") @@ -385,6 +387,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): eval_info = eval_policy_all( envs=eval_env, # dict[suite][task_id] -> vec_env policy=accelerator.unwrap_model(policy), + env_preprocessor=env_preprocessor, + env_postprocessor=env_postprocessor, preprocessor=preprocessor, postprocessor=postprocessor, n_episodes=cfg.eval.n_episodes,