diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index caa6f91ed..355a6560f 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -14,12 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +from typing import Any import gymnasium as gym from gymnasium.envs.registration import registry as gym_registry from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result +from lerobot.processor.observation_processor import LiberoProcessorStep +from lerobot.processor.pipeline import PolicyProcessorPipeline def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -33,6 +36,31 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: raise ValueError(f"Policy type '{env_type}' is not available.") +def make_env_pre_post_processors( + env_cfg: EnvConfig, +) -> PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]: + """ + Create a preprocessor pipeline for environment observations. + + This function creates a processor pipeline that transforms raw environment + observations into the format expected by policies. By default, it returns + an identity processor that does nothing. For specific environments like + LIBERO, it adds environment-specific processing steps. + + Args: + env_cfg: The configuration of the environment. + + Returns: + A PolicyProcessorPipeline that processes environment observations. + """ + # For LIBERO environments, add the LiberoProcessorStep + if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type: + return PolicyProcessorPipeline(steps=[LiberoProcessorStep()]) + + # For all other environments, return an identity processor (does nothing) + return PolicyProcessorPipeline(steps=[]) + + def make_env( cfg: EnvConfig | str, n_envs: int = 1, diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 0d66fa1aa..c9d1b49f9 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -71,7 +71,7 @@ from tqdm import trange from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig -from lerobot.envs.factory import make_env +from lerobot.envs.factory import make_env, make_env_pre_post_processors from lerobot.envs.utils import ( add_envs_task, check_env_attributes_and_types, @@ -94,6 +94,7 @@ from lerobot.utils.utils import ( def rollout( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, + env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], seeds: list[int] | None = None, @@ -165,6 +166,10 @@ def rollout( # Infer "task" from attributes of environments. # TODO: works with SyncVectorEnv but not AsyncVectorEnv observation = add_envs_task(env, observation) + + # Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO) + observation = env_preprocessor(observation) + observation = preprocessor(observation) with torch.inference_mode(): action = policy.select_action(observation) @@ -239,6 +244,7 @@ def rollout( def eval_policy( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, + env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], n_episodes: int, @@ -319,6 +325,7 @@ def eval_policy( rollout_data = rollout( env=env, policy=policy, + env_preprocessor=env_preprocessor, preprocessor=preprocessor, postprocessor=postprocessor, seeds=list(seeds) if seeds else None, @@ -517,10 +524,15 @@ def eval_main(cfg: EvalPipelineConfig): pretrained_path=cfg.policy.pretrained_path, preprocessor_overrides=preprocessor_overrides, ) + + # Create environment-specific preprocessor (e.g., for LIBERO environments) + env_preprocessor = make_env_pre_post_processors(env_cfg=cfg.env) + with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): info = eval_policy_all( envs=envs, policy=policy, + env_preprocessor=env_preprocessor, preprocessor=preprocessor, postprocessor=postprocessor, n_episodes=cfg.eval.n_episodes, @@ -561,6 +573,7 @@ def eval_one( env: gym.vector.VectorEnv, *, policy: PreTrainedPolicy, + env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], n_episodes: int, @@ -576,6 +589,7 @@ def eval_one( task_result = eval_policy( env=env, policy=policy, + env_preprocessor=env_preprocessor, preprocessor=preprocessor, postprocessor=postprocessor, n_episodes=n_episodes, @@ -600,6 +614,7 @@ def run_one( env, *, policy, + env_preprocessor, preprocessor, postprocessor, n_episodes: int, @@ -622,6 +637,7 @@ def run_one( metrics = eval_one( env, policy=policy, + env_preprocessor=env_preprocessor, preprocessor=preprocessor, postprocessor=postprocessor, n_episodes=n_episodes, @@ -639,6 +655,7 @@ def run_one( def eval_policy_all( envs: dict[str, dict[int, gym.vector.VectorEnv]], policy, + env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], n_episodes: int, @@ -694,6 +711,7 @@ def eval_policy_all( task_runner = partial( run_one, policy=policy, + env_preprocessor=env_preprocessor, preprocessor=preprocessor, postprocessor=postprocessor, n_episodes=n_episodes,