mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions
- Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline. - Adjusted the implementation to apply preprocessing and postprocessing steps during policy evaluation, improving the overall data handling and processing flow.
This commit is contained in:
@@ -57,6 +57,7 @@ from dataclasses import asdict
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
@@ -71,7 +72,9 @@ from lerobot.configs.eval import EvalPipelineConfig
|
|||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env
|
||||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.policies.utils import get_device_from_parameters
|
||||||
from lerobot.processor.core import TransitionKey
|
from lerobot.processor.core import TransitionKey
|
||||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
from lerobot.utils.io_utils import write_video
|
from lerobot.utils.io_utils import write_video
|
||||||
@@ -88,6 +91,8 @@ def rollout(
|
|||||||
policy: PreTrainedPolicy,
|
policy: PreTrainedPolicy,
|
||||||
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||||
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||||
|
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||||
|
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||||
seeds: list[int] | None = None,
|
seeds: list[int] | None = None,
|
||||||
return_observations: bool = False,
|
return_observations: bool = False,
|
||||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||||
@@ -151,6 +156,7 @@ def rollout(
|
|||||||
while not np.all(done):
|
while not np.all(done):
|
||||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||||
observation = preprocess_observation(observation)
|
observation = preprocess_observation(observation)
|
||||||
|
observation = preprocessor(observation)
|
||||||
if return_observations:
|
if return_observations:
|
||||||
all_observations.append(deepcopy(observation))
|
all_observations.append(deepcopy(observation))
|
||||||
|
|
||||||
@@ -161,9 +167,11 @@ def rollout(
|
|||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||||
|
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||||
|
|
||||||
# Convert to CPU / numpy.
|
# Convert to CPU / numpy.
|
||||||
action: np.ndarray = action.to("cpu").numpy()
|
action: np.ndarray = action.to("cpu").numpy()
|
||||||
|
action: np.ndarray = action.to("cpu").numpy()
|
||||||
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||||
|
|
||||||
# Apply the next action.
|
# Apply the next action.
|
||||||
@@ -222,6 +230,8 @@ def eval_policy(
|
|||||||
policy: PreTrainedPolicy,
|
policy: PreTrainedPolicy,
|
||||||
preprocessor: PolicyProcessorPipeline,
|
preprocessor: PolicyProcessorPipeline,
|
||||||
postprocessor: PolicyProcessorPipeline,
|
postprocessor: PolicyProcessorPipeline,
|
||||||
|
preprocessor: PolicyProcessorPipeline,
|
||||||
|
postprocessor: PolicyProcessorPipeline,
|
||||||
n_episodes: int,
|
n_episodes: int,
|
||||||
max_episodes_rendered: int = 0,
|
max_episodes_rendered: int = 0,
|
||||||
videos_dir: Path | None = None,
|
videos_dir: Path | None = None,
|
||||||
@@ -298,6 +308,10 @@ def eval_policy(
|
|||||||
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
||||||
)
|
)
|
||||||
rollout_data = rollout(
|
rollout_data = rollout(
|
||||||
|
env=env,
|
||||||
|
policy=policy,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
env=env,
|
env=env,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
@@ -484,13 +498,22 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
env_cfg=cfg.env,
|
env_cfg=cfg.env,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
policy.eval()
|
policy.eval()
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
|
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
|
||||||
)
|
)
|
||||||
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
|
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
|
||||||
|
)
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
info = eval_policy(
|
info = eval_policy(
|
||||||
|
env=env,
|
||||||
|
policy=policy,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
n_episodes=cfg.eval.n_episodes,
|
||||||
env=env,
|
env=env,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
|
|||||||
Reference in New Issue
Block a user