precommit

This commit is contained in:
Jade Choghari
2025-09-17 11:20:16 +02:00
parent 2c17433f4d
commit aa517b5780
4 changed files with 14 additions and 10 deletions
+1 -1
View File
@@ -55,4 +55,4 @@ python src/lerobot/scripts/eval.py \
# --num_trials_per_task 10 \ # --num_trials_per_task 10 \
# --video_out_path "data/libero/videos" \ # --video_out_path "data/libero/videos" \
# --device "cuda" \ # --device "cuda" \
# --seed 7 # --seed 7
+1 -1
View File
@@ -323,4 +323,4 @@ class LiberoEnv(EnvConfig):
return { return {
"obs_type": self.obs_type, "obs_type": self.obs_type,
"render_mode": self.render_mode, "render_mode": self.render_mode,
} }
+9 -5
View File
@@ -672,11 +672,15 @@ def eval_policy_all(
""" """
if max_parallel_tasks == 1: if max_parallel_tasks == 1:
yield from _eval_monotask( yield from _eval_monotask(
envs, policy, preprocessor=preprocessor, envs,
postprocessor=postprocessor, n_episodes=n_episodes, policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered, max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir, return_episode_data=return_episode_data, videos_dir=videos_dir,
start_seed=start_seed return_episode_data=return_episode_data,
start_seed=start_seed,
) )
else: else:
yield from _eval_parallel( yield from _eval_parallel(
@@ -753,4 +757,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
+3 -3
View File
@@ -181,7 +181,7 @@ def train(cfg: TrainPipelineConfig):
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
) )
logging.info("Creating optimizer and scheduler") logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
@@ -294,7 +294,7 @@ def train(cfg: TrainPipelineConfig):
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
): ):
eval_info = eval_policy_all( eval_info = eval_policy_all(
env=eval_env, # dict[suite][task_id] -> vec_env env=eval_env, # dict[suite][task_id] -> vec_env
policy=policy, policy=policy,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
@@ -348,4 +348,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()