From b196f04d48549be7fa4918a71ec1c903cbca1aba Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 18 Nov 2025 18:29:19 +0100 Subject: [PATCH] more changes --- docs/source/env_processor.mdx | 0 src/lerobot/envs/factory.py | 33 ++-- src/lerobot/processor/env_processor.py | 152 ++++++++++++++++++ .../processor/observation_processor.py | 131 --------------- src/lerobot/scripts/lerobot_eval.py | 18 ++- 5 files changed, 189 insertions(+), 145 deletions(-) create mode 100644 docs/source/env_processor.mdx create mode 100644 src/lerobot/processor/env_processor.py diff --git a/docs/source/env_processor.mdx b/docs/source/env_processor.mdx new file mode 100644 index 000000000..e69de29bb diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index 927be754e..9e73a57b9 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -21,7 +21,7 @@ from gymnasium.envs.registration import registry as gym_registry from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result -from lerobot.processor.observation_processor import LiberoProcessorStep +from lerobot.processor.env_processor import LiberoProcessorStep from lerobot.processor.pipeline import PolicyProcessorPipeline @@ -38,27 +38,36 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: def make_env_pre_post_processors( env_cfg: EnvConfig, -) -> PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]: +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], +]: """ - Create a preprocessor pipeline for environment observations. + Create preprocessor and postprocessor pipelines for environment observations. - This function creates a processor pipeline that transforms raw environment - observations into the format expected by policies. By default, it returns - an identity processor that does nothing. For specific environments like - LIBERO, it adds environment-specific processing steps. + This function creates processor pipelines that transform raw environment + observations and actions. By default, it returns identity processors that do nothing. + For specific environments like LIBERO, it adds environment-specific processing steps. Args: env_cfg: The configuration of the environment. Returns: - A PolicyProcessorPipeline that processes environment observations. + A tuple containing: + - preprocessor: Pipeline that processes environment observations + - postprocessor: Pipeline that processes environment outputs (currently identity) """ - # For LIBERO environments, add the LiberoProcessorStep + # For LIBERO environments, add the LiberoProcessorStep to preprocessor if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type: - return PolicyProcessorPipeline(steps=[LiberoProcessorStep()]) + preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()]) + else: + # For all other environments, return an identity preprocessor (does nothing) + preprocessor = PolicyProcessorPipeline(steps=[]) - # For all other environments, return an identity processor (does nothing) - return PolicyProcessorPipeline(steps=[]) + # Postprocessor is currently identity for all environments + postprocessor = PolicyProcessorPipeline(steps=[]) + + return preprocessor, postprocessor def make_env( diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py new file mode 100644 index 000000000..fd0b8c41f --- /dev/null +++ b/src/lerobot/processor/env_processor.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from dataclasses import dataclass + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE + +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry + +@dataclass +@ProcessorStepRegistry.register(name="libero_processor") +class LiberoProcessorStep(ObservationProcessorStep): + """ + Processes LIBERO observations into the LeRobot format. + + This step handles the specific observation structure from LIBERO environments, + which includes nested robot_state dictionaries and image observations. + + **State Processing:** + - Processes the `robot_state` dictionary which contains nested end-effector, + gripper, and joint information. + - Extracts and concatenates: + - End-effector position (3D) + - End-effector quaternion converted to axis-angle (3D) + - Gripper joint positions (2D) + - Maps the concatenated state to `"observation.state"`. + + **Image Processing:** + - Rotates images by 180 degrees by flipping both height and width dimensions. + - This accounts for the HuggingFaceVLA/libero camera orientation convention. + """ + + def _process_observation(self, observation): + """ + Processes both image and robot_state observations from LIBERO. + """ + processed_obs = observation.copy() + for key in list(processed_obs.keys()): + if key.startswith(f"{OBS_IMAGES}."): + img = processed_obs[key] + + # Flip both H and W + img = torch.flip(img, dims=[2, 3]) + + processed_obs[key] = img + # Process robot_state into a flat state vector + if "observation.robot_state" in processed_obs: + robot_state = processed_obs.pop("observation.robot_state") + + # Extract components + eef_pos = robot_state["eef"]["pos"] # (B, 3,) + eef_quat = robot_state["eef"]["quat"] # (B, 4,) + gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,) + + # Convert quaternion to axis-angle + eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3) + # Concatenate into a single state vector + state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1) + + # ensure float32 + state = state.float() + if state.dim() == 1: + state = state.unsqueeze(0) + + processed_obs[OBS_STATE] = state + return processed_obs + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Transforms feature keys from the LIBERO format to the LeRobot standard. + """ + new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {} + + # copy over non-STATE features + for ft, feats in features.items(): + if ft != PipelineFeatureType.STATE: + new_features[ft] = feats.copy() + + # rebuild STATE features + state_feats = {} + + # add our new flattened state + state_feats["observation.state"] = PolicyFeature( + key="observation.state", + shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] + dtype="float32", + description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."), + ) + + new_features[PipelineFeatureType.STATE] = state_feats + + return new_features + + def observation(self, observation): + return self._process_observation(observation) + + def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor: + """ + Convert batched quaternions to axis-angle format. + Only accepts torch tensors of shape (B, 4). + + Args: + quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format + + Returns: + Tensor: (B, 3) axis-angle vectors + + Raises: + TypeError: if input is not a torch tensor + ValueError: if shape is not (B, 4) + """ + + if not isinstance(quat, torch.Tensor): + raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}") + + if quat.ndim != 2 or quat.shape[1] != 4: + raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}") + + quat = quat.to(dtype=torch.float32) + device = quat.device + batch_size = quat.shape[0] + + w = quat[:, 3].clamp(-1.0, 1.0) + + den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0)) + + result = torch.zeros((batch_size, 3), device=device) + + mask = den > 1e-10 + + if mask.any(): + angle = 2.0 * torch.acos(w[mask]) # (M,) + axis = quat[mask, :3] / den[mask].unsqueeze(1) + result[mask] = axis * angle.unsqueeze(1) + + return result diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index c1a1a3024..d22d8fb96 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -204,134 +204,3 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): new_features[src_ft][key] = feat return new_features - - -@dataclass -@ProcessorStepRegistry.register(name="libero_processor") -class LiberoProcessorStep(ObservationProcessorStep): - """ - Processes LIBERO observations into the LeRobot format. - - This step handles the specific observation structure from LIBERO environments, - which includes nested robot_state dictionaries and image observations. - - **State Processing:** - - Processes the `robot_state` dictionary which contains nested end-effector, - gripper, and joint information. - - Extracts and concatenates: - - End-effector position (3D) - - End-effector quaternion converted to axis-angle (3D) - - Gripper joint positions (2D) - - Maps the concatenated state to `"observation.state"`. - - **Image Processing:** - - Rotates images by 180 degrees by flipping both height and width dimensions. - - This accounts for the HuggingFaceVLA/libero camera orientation convention. - """ - - def _process_observation(self, observation): - """ - Processes both image and robot_state observations from LIBERO. - """ - processed_obs = observation.copy() - for key in list(processed_obs.keys()): - if key.startswith(f"{OBS_IMAGES}."): - img = processed_obs[key] - - # Flip both H and W - img = torch.flip(img, dims=[2, 3]) - - processed_obs[key] = img - # Process robot_state into a flat state vector - if "observation.robot_state" in processed_obs: - robot_state = processed_obs.pop("observation.robot_state") - - # Extract components - eef_pos = robot_state["eef"]["pos"] # (B, 3,) - eef_quat = robot_state["eef"]["quat"] # (B, 4,) - gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,) - - # Convert quaternion to axis-angle - eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3) - # Concatenate into a single state vector - state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1) - - # ensure float32 - state = state.float() - if state.dim() == 1: - state = state.unsqueeze(0) - - processed_obs[OBS_STATE] = state - return processed_obs - - def transform_features( - self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] - ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: - """ - Transforms feature keys from the LIBERO format to the LeRobot standard. - """ - new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {} - - # copy over non-STATE features - for ft, feats in features.items(): - if ft != PipelineFeatureType.STATE: - new_features[ft] = feats.copy() - - # rebuild STATE features - state_feats = {} - - # add our new flattened state - state_feats["observation.state"] = PolicyFeature( - key="observation.state", - shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] - dtype="float32", - description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."), - ) - - new_features[PipelineFeatureType.STATE] = state_feats - - return new_features - - def observation(self, observation): - return self._process_observation(observation) - - def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor: - """ - Convert batched quaternions to axis-angle format. - Only accepts torch tensors of shape (B, 4). - - Args: - quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format - - Returns: - Tensor: (B, 3) axis-angle vectors - - Raises: - TypeError: if input is not a torch tensor - ValueError: if shape is not (B, 4) - """ - - if not isinstance(quat, torch.Tensor): - raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}") - - if quat.ndim != 2 or quat.shape[1] != 4: - raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}") - - quat = quat.to(dtype=torch.float32) - device = quat.device - batch_size = quat.shape[0] - - w = quat[:, 3].clamp(-1.0, 1.0) - - den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0)) - - result = torch.zeros((batch_size, 3), device=device) - - mask = den > 1e-10 - - if mask.any(): - angle = 2.0 * torch.acos(w[mask]) # (M,) - axis = quat[mask, :3] / den[mask].unsqueeze(1) - result[mask] = axis * angle.unsqueeze(1) - - return result diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index c1b821635..4cf9c4095 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -95,6 +95,7 @@ def rollout( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], seeds: list[int] | None = None, @@ -175,6 +176,10 @@ def rollout( action = policy.select_action(observation) action = postprocessor(action) + action_transition = {"action": action} + action_transition = env_postprocessor(action_transition) + action = action_transition["action"] + # Convert to CPU / numpy. action_numpy: np.ndarray = action.to("cpu").numpy() assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" @@ -245,6 +250,7 @@ def eval_policy( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], n_episodes: int, @@ -326,6 +332,7 @@ def eval_policy( env=env, policy=policy, env_preprocessor=env_preprocessor, + env_postprocessor=env_postprocessor, preprocessor=preprocessor, postprocessor=postprocessor, seeds=list(seeds) if seeds else None, @@ -525,14 +532,15 @@ def eval_main(cfg: EvalPipelineConfig): preprocessor_overrides=preprocessor_overrides, ) - # Create environment-specific preprocessor (e.g., for LIBERO environments) - env_preprocessor = make_env_pre_post_processors(env_cfg=cfg.env) + # Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments) + env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env) with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): info = eval_policy_all( envs=envs, policy=policy, env_preprocessor=env_preprocessor, + env_postprocessor=env_postprocessor, preprocessor=preprocessor, postprocessor=postprocessor, n_episodes=cfg.eval.n_episodes, @@ -574,6 +582,7 @@ def eval_one( *, policy: PreTrainedPolicy, env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], n_episodes: int, @@ -590,6 +599,7 @@ def eval_one( env=env, policy=policy, env_preprocessor=env_preprocessor, + env_postprocessor=env_postprocessor, preprocessor=preprocessor, postprocessor=postprocessor, n_episodes=n_episodes, @@ -615,6 +625,7 @@ def run_one( *, policy, env_preprocessor, + env_postprocessor, preprocessor, postprocessor, n_episodes: int, @@ -638,6 +649,7 @@ def run_one( env, policy=policy, env_preprocessor=env_preprocessor, + env_postprocessor=env_postprocessor, preprocessor=preprocessor, postprocessor=postprocessor, n_episodes=n_episodes, @@ -656,6 +668,7 @@ def eval_policy_all( envs: dict[str, dict[int, gym.vector.VectorEnv]], policy, env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], n_episodes: int, @@ -712,6 +725,7 @@ def eval_policy_all( run_one, policy=policy, env_preprocessor=env_preprocessor, + env_postprocessor=env_postprocessor, preprocessor=preprocessor, postprocessor=postprocessor, n_episodes=n_episodes,