mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-30 22:57:00 +00:00
fix video paths and train.py
This commit is contained in:
@@ -257,7 +257,6 @@ def eval_policy(
|
||||
# Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly
|
||||
# divisible by env.num_envs we end up discarding some data in the last batch.
|
||||
n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0)
|
||||
print("n_batches", n_batches)
|
||||
|
||||
# Keep track of some metrics.
|
||||
sum_rewards = []
|
||||
@@ -565,12 +564,16 @@ def eval_policy_multitask(
|
||||
def eval_task(task_group, task_id, env):
|
||||
"""Evaluates a single task in parallel."""
|
||||
print(f"Evaluating: task_group: {task_group}, task_id: {task_id} ...")
|
||||
# jadechoghari : added multi video eval support
|
||||
if videos_dir is not None:
|
||||
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
task_result = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
n_episodes,
|
||||
max_episodes_rendered,
|
||||
videos_dir,
|
||||
task_videos_dir,
|
||||
return_episode_data,
|
||||
start_seed,
|
||||
# verbose=verbose,
|
||||
|
||||
@@ -262,17 +262,15 @@ def train(cfg: TrainPipelineConfig):
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
)
|
||||
aggregated_results = eval_info["overall"]["aggregated"]
|
||||
aggregated = eval_info["overall"]["aggregated"]
|
||||
# Print per-suite stats
|
||||
for task_group, task_group_info in eval_info.items():
|
||||
if task_group == "overall":
|
||||
continue # Skip the overall stats since we already printed it
|
||||
print(f"\nAggregated Metrics for {task_group}:")
|
||||
print(task_group_info["aggregated"])
|
||||
breakpoint()
|
||||
else:
|
||||
print("START EVAL")
|
||||
breakpoint()
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
@@ -302,7 +300,13 @@ def train(cfg: TrainPipelineConfig):
|
||||
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
# added by jade, close all env in multi eval setup
|
||||
if cfg.env.multitask_eval:
|
||||
for task_group, envs_dict in eval_env.items():
|
||||
for idx, env in envs_dict.items():
|
||||
env.close()
|
||||
else:
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
|
||||
Reference in New Issue
Block a user