add env processor

This commit is contained in:
Jade Choghari
2025-11-18 17:27:43 +01:00
parent 8915c6cd25
commit b257e02ccf
2 changed files with 47 additions and 1 deletions
+28
View File
@@ -14,12 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.processor.observation_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -33,6 +36,31 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]:
"""
Create a preprocessor pipeline for environment observations.
This function creates a processor pipeline that transforms raw environment
observations into the format expected by policies. By default, it returns
an identity processor that does nothing. For specific environments like
LIBERO, it adds environment-specific processing steps.
Args:
env_cfg: The configuration of the environment.
Returns:
A PolicyProcessorPipeline that processes environment observations.
"""
# For LIBERO environments, add the LiberoProcessorStep
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
return PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
# For all other environments, return an identity processor (does nothing)
return PolicyProcessorPipeline(steps=[])
def make_env(
cfg: EnvConfig | str,
n_envs: int = 1,
+19 -1
View File
@@ -71,7 +71,7 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import (
add_envs_task,
check_env_attributes_and_types,
@@ -94,6 +94,7 @@ from lerobot.utils.utils import (
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
seeds: list[int] | None = None,
@@ -165,6 +166,10 @@ def rollout(
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
@@ -239,6 +244,7 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -319,6 +325,7 @@ def eval_policy(
rollout_data = rollout(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
seeds=list(seeds) if seeds else None,
@@ -517,10 +524,15 @@ def eval_main(cfg: EvalPipelineConfig):
pretrained_path=cfg.policy.pretrained_path,
preprocessor_overrides=preprocessor_overrides,
)
# Create environment-specific preprocessor (e.g., for LIBERO environments)
env_preprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
envs=envs,
policy=policy,
env_preprocessor=env_preprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
@@ -561,6 +573,7 @@ def eval_one(
env: gym.vector.VectorEnv,
*,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -576,6 +589,7 @@ def eval_one(
task_result = eval_policy(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -600,6 +614,7 @@ def run_one(
env,
*,
policy,
env_preprocessor,
preprocessor,
postprocessor,
n_episodes: int,
@@ -622,6 +637,7 @@ def run_one(
metrics = eval_one(
env,
policy=policy,
env_preprocessor=env_preprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -639,6 +655,7 @@ def run_one(
def eval_policy_all(
envs: dict[str, dict[int, gym.vector.VectorEnv]],
policy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -694,6 +711,7 @@ def eval_policy_all(
task_runner = partial(
run_one,
policy=policy,
env_preprocessor=env_preprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,