mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
precommit
This commit is contained in:
+1
-1
@@ -55,4 +55,4 @@ python src/lerobot/scripts/eval.py \
|
|||||||
# --num_trials_per_task 10 \
|
# --num_trials_per_task 10 \
|
||||||
# --video_out_path "data/libero/videos" \
|
# --video_out_path "data/libero/videos" \
|
||||||
# --device "cuda" \
|
# --device "cuda" \
|
||||||
# --seed 7
|
# --seed 7
|
||||||
|
|||||||
@@ -323,4 +323,4 @@ class LiberoEnv(EnvConfig):
|
|||||||
return {
|
return {
|
||||||
"obs_type": self.obs_type,
|
"obs_type": self.obs_type,
|
||||||
"render_mode": self.render_mode,
|
"render_mode": self.render_mode,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -672,11 +672,15 @@ def eval_policy_all(
|
|||||||
"""
|
"""
|
||||||
if max_parallel_tasks == 1:
|
if max_parallel_tasks == 1:
|
||||||
yield from _eval_monotask(
|
yield from _eval_monotask(
|
||||||
envs, policy, preprocessor=preprocessor,
|
envs,
|
||||||
postprocessor=postprocessor, n_episodes=n_episodes,
|
policy,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
n_episodes=n_episodes,
|
||||||
max_episodes_rendered=max_episodes_rendered,
|
max_episodes_rendered=max_episodes_rendered,
|
||||||
videos_dir=videos_dir, return_episode_data=return_episode_data,
|
videos_dir=videos_dir,
|
||||||
start_seed=start_seed
|
return_episode_data=return_episode_data,
|
||||||
|
start_seed=start_seed,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield from _eval_parallel(
|
yield from _eval_parallel(
|
||||||
@@ -753,4 +757,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
|
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Creating optimizer and scheduler")
|
logging.info("Creating optimizer and scheduler")
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||||
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
||||||
@@ -294,7 +294,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||||
):
|
):
|
||||||
eval_info = eval_policy_all(
|
eval_info = eval_policy_all(
|
||||||
env=eval_env, # dict[suite][task_id] -> vec_env
|
env=eval_env, # dict[suite][task_id] -> vec_env
|
||||||
policy=policy,
|
policy=policy,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
@@ -348,4 +348,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user