Compare commits

...

4 Commits

Author SHA1 Message Date
AdilZouitine 15960f0b5e refactor(utils): enhance task handling in add_envs_task function
- Improved the `add_envs_task` function to validate the output of `task_description` and `task` calls, ensuring they return lists of strings.
- Removed the use of `else` statement for environments without language instructions, simplifying the logic and enhancing readability.
- Streamlined the observation dictionary handling by ensuring consistent data types for task attributes.
2025-09-10 10:05:43 +02:00
AdilZouitine 8b43339563 debug 2025-09-10 10:05:43 +02:00
AdilZouitine 5dababd21e refactor(eval): remove redundant observation device conversion in rollout function
- Eliminated unnecessary device conversion for the observation dictionary within the `rollout` function, streamlining the code and enhancing readability.
- This change simplifies the observation handling process, aligning with the preference for clearer solutions.
2025-09-10 10:05:43 +02:00
AdilZouitine cbc46467b3 refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions
- Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline.
- Adjusted the implementation to apply preprocessing and postprocessing steps during policy evaluation, improving the overall data handling and processing flow.
2025-09-10 10:05:43 +02:00
+21
View File
@@ -57,6 +57,7 @@ from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import Any
from typing import Any
import einops
import gymnasium as gym
@@ -71,6 +72,7 @@ from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor.core import TransitionKey
from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -88,6 +90,8 @@ def rollout(
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
@@ -161,9 +165,11 @@ def rollout(
with torch.inference_mode():
action = policy.select_action(observation)
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
# Convert to CPU / numpy.
action: np.ndarray = action.to("cpu").numpy()
action: np.ndarray = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
@@ -222,6 +228,8 @@ def eval_policy(
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,
@@ -298,6 +306,10 @@ def eval_policy(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
)
rollout_data = rollout(
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
env=env,
policy=policy,
preprocessor=preprocessor,
@@ -484,13 +496,22 @@ def eval_main(cfg: EvalPipelineConfig):
env_cfg=cfg.env,
)
policy.eval()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
env=env,
policy=policy,
preprocessor=preprocessor,