From aa517b57807b605db106cd9d7d909e5d8c0aa777 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Wed, 17 Sep 2025 11:20:16 +0200 Subject: [PATCH] precommit --- examples/test.sh | 2 +- src/lerobot/envs/configs.py | 2 +- src/lerobot/scripts/eval.py | 14 +++++++++----- src/lerobot/scripts/train.py | 6 +++--- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/test.sh b/examples/test.sh index a34f85719..7c76945be 100644 --- a/examples/test.sh +++ b/examples/test.sh @@ -55,4 +55,4 @@ python src/lerobot/scripts/eval.py \ # --num_trials_per_task 10 \ # --video_out_path "data/libero/videos" \ # --device "cuda" \ -# --seed 7 \ No newline at end of file +# --seed 7 diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index c3fdbbff5..3b0932cc6 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -323,4 +323,4 @@ class LiberoEnv(EnvConfig): return { "obs_type": self.obs_type, "render_mode": self.render_mode, - } \ No newline at end of file + } diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 0600ddb19..ba8dec6a4 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -672,11 +672,15 @@ def eval_policy_all( """ if max_parallel_tasks == 1: yield from _eval_monotask( - envs, policy, preprocessor=preprocessor, - postprocessor=postprocessor, n_episodes=n_episodes, + envs, + policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=n_episodes, max_episodes_rendered=max_episodes_rendered, - videos_dir=videos_dir, return_episode_data=return_episode_data, - start_seed=start_seed + videos_dir=videos_dir, + return_episode_data=return_episode_data, + start_seed=start_seed, ) else: yield from _eval_parallel( @@ -753,4 +757,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index b31adaaf0..1655afa85 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -181,7 +181,7 @@ def train(cfg: TrainPipelineConfig): preprocessor, postprocessor = make_pre_post_processors( policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs ) - + logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) @@ -294,7 +294,7 @@ def train(cfg: TrainPipelineConfig): torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), ): eval_info = eval_policy_all( - env=eval_env, # dict[suite][task_id] -> vec_env + env=eval_env, # dict[suite][task_id] -> vec_env policy=policy, preprocessor=preprocessor, postprocessor=postprocessor, @@ -348,4 +348,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()