precommit

This commit is contained in:
Jade Choghari
2025-09-17 11:20:16 +02:00
parent 2c17433f4d
commit aa517b5780
4 changed files with 14 additions and 10 deletions
+8 -4
View File
@@ -672,11 +672,15 @@ def eval_policy_all(
""" """
if max_parallel_tasks == 1: if max_parallel_tasks == 1:
yield from _eval_monotask( yield from _eval_monotask(
envs, policy, preprocessor=preprocessor, envs,
postprocessor=postprocessor, n_episodes=n_episodes, policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered, max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir, return_episode_data=return_episode_data, videos_dir=videos_dir,
start_seed=start_seed return_episode_data=return_episode_data,
start_seed=start_seed,
) )
else: else:
yield from _eval_parallel( yield from _eval_parallel(
+1 -1
View File
@@ -294,7 +294,7 @@ def train(cfg: TrainPipelineConfig):
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
): ):
eval_info = eval_policy_all( eval_info = eval_policy_all(
env=eval_env, # dict[suite][task_id] -> vec_env env=eval_env, # dict[suite][task_id] -> vec_env
policy=policy, policy=policy,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,