mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix test, add to train
This commit is contained in:
@@ -29,7 +29,7 @@ from lerobot.configs.train import TrainPipelineConfig
|
|||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets.factory import make_dataset
|
||||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||||
from lerobot.datasets.utils import cycle
|
from lerobot.datasets.utils import cycle
|
||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
||||||
from lerobot.envs.utils import close_envs
|
from lerobot.envs.utils import close_envs
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
@@ -259,6 +259,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||||
if cfg.env is not None:
|
if cfg.env is not None:
|
||||||
logging.info(f"{cfg.env.task=}")
|
logging.info(f"{cfg.env.task=}")
|
||||||
|
logging.info("Creating environment processors")
|
||||||
|
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||||
logging.info(f"{dataset.num_episodes=}")
|
logging.info(f"{dataset.num_episodes=}")
|
||||||
@@ -385,6 +387,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
eval_info = eval_policy_all(
|
eval_info = eval_policy_all(
|
||||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||||
policy=accelerator.unwrap_model(policy),
|
policy=accelerator.unwrap_model(policy),
|
||||||
|
env_preprocessor=env_preprocessor,
|
||||||
|
env_postprocessor=env_postprocessor,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
n_episodes=cfg.eval.n_episodes,
|
n_episodes=cfg.eval.n_episodes,
|
||||||
|
|||||||
Reference in New Issue
Block a user