refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions

- Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline.
- Adjusted the implementation to apply preprocessing and postprocessing steps during policy evaluation, improving the overall data handling and processing flow.
This commit is contained in:
AdilZouitine
2025-09-09 12:59:58 +02:00
parent e881fb6678
commit cbc46467b3
+23
View File
@@ -57,6 +57,7 @@ from dataclasses import asdict
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Any from typing import Any
from typing import Any
import einops import einops
import gymnasium as gym import gymnasium as gym
@@ -71,7 +72,9 @@ from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.processor.core import TransitionKey from lerobot.processor.core import TransitionKey
from lerobot.processor.pipeline import PolicyProcessorPipeline from lerobot.processor.pipeline import PolicyProcessorPipeline
from lerobot.utils.io_utils import write_video from lerobot.utils.io_utils import write_video
@@ -88,6 +91,8 @@ def rollout(
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]], postprocessor: PolicyProcessorPipeline[dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
seeds: list[int] | None = None, seeds: list[int] | None = None,
return_observations: bool = False, return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None, render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
@@ -151,6 +156,7 @@ def rollout(
while not np.all(done): while not np.all(done):
# Numpy array to tensor and changing dictionary keys to LeRobot policy format. # Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation) observation = preprocess_observation(observation)
observation = preprocessor(observation)
if return_observations: if return_observations:
all_observations.append(deepcopy(observation)) all_observations.append(deepcopy(observation))
@@ -161,9 +167,11 @@ def rollout(
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation) action = policy.select_action(observation)
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION] action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
# Convert to CPU / numpy. # Convert to CPU / numpy.
action: np.ndarray = action.to("cpu").numpy() action: np.ndarray = action.to("cpu").numpy()
action: np.ndarray = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)" assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action. # Apply the next action.
@@ -222,6 +230,8 @@ def eval_policy(
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline, preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
n_episodes: int, n_episodes: int,
max_episodes_rendered: int = 0, max_episodes_rendered: int = 0,
videos_dir: Path | None = None, videos_dir: Path | None = None,
@@ -298,6 +308,10 @@ def eval_policy(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
) )
rollout_data = rollout( rollout_data = rollout(
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
env=env, env=env,
policy=policy, policy=policy,
preprocessor=preprocessor, preprocessor=preprocessor,
@@ -484,13 +498,22 @@ def eval_main(cfg: EvalPipelineConfig):
env_cfg=cfg.env, env_cfg=cfg.env,
) )
policy.eval() policy.eval()
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
) )
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy( info = eval_policy(
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
env=env, env=env,
policy=policy, policy=policy,
preprocessor=preprocessor, preprocessor=preprocessor,