more changes

This commit is contained in:
Jade Choghari
2025-11-18 18:29:19 +01:00
parent 91768a9879
commit b196f04d48
5 changed files with 189 additions and 145 deletions
View File
+21 -12
View File
@@ -21,7 +21,7 @@ from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.processor.observation_processor import LiberoProcessorStep from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -38,27 +38,36 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
def make_env_pre_post_processors( def make_env_pre_post_processors(
env_cfg: EnvConfig, env_cfg: EnvConfig,
) -> PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]: ) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
""" """
Create a preprocessor pipeline for environment observations. Create preprocessor and postprocessor pipelines for environment observations.
This function creates a processor pipeline that transforms raw environment This function creates processor pipelines that transform raw environment
observations into the format expected by policies. By default, it returns observations and actions. By default, it returns identity processors that do nothing.
an identity processor that does nothing. For specific environments like For specific environments like LIBERO, it adds environment-specific processing steps.
LIBERO, it adds environment-specific processing steps.
Args: Args:
env_cfg: The configuration of the environment. env_cfg: The configuration of the environment.
Returns: Returns:
A PolicyProcessorPipeline that processes environment observations. A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs (currently identity)
""" """
# For LIBERO environments, add the LiberoProcessorStep # For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type: if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
return PolicyProcessorPipeline(steps=[LiberoProcessorStep()]) preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
else:
# For all other environments, return an identity preprocessor (does nothing)
preprocessor = PolicyProcessorPipeline(steps=[])
# For all other environments, return an identity processor (does nothing) # Postprocessor is currently identity for all environments
return PolicyProcessorPipeline(steps=[]) postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
def make_env( def make_env(
+152
View File
@@ -0,0 +1,152 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from dataclasses import dataclass
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_quat = robot_state["eef"]["quat"] # (B, 4,)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into a single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def observation(self, observation):
return self._process_observation(observation)
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
"""
Convert batched quaternions to axis-angle format.
Only accepts torch tensors of shape (B, 4).
Args:
quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format
Returns:
Tensor: (B, 3) axis-angle vectors
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 4)
"""
if not isinstance(quat, torch.Tensor):
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
if quat.ndim != 2 or quat.shape[1] != 4:
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
quat = quat.to(dtype=torch.float32)
device = quat.device
batch_size = quat.shape[0]
w = quat[:, 3].clamp(-1.0, 1.0)
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
result = torch.zeros((batch_size, 3), device=device)
mask = den > 1e-10
if mask.any():
angle = 2.0 * torch.acos(w[mask]) # (M,)
axis = quat[mask, :3] / den[mask].unsqueeze(1)
result[mask] = axis * angle.unsqueeze(1)
return result
@@ -204,134 +204,3 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
new_features[src_ft][key] = feat new_features[src_ft][key] = feat
return new_features return new_features
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_quat = robot_state["eef"]["quat"] # (B, 4,)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into a single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def observation(self, observation):
return self._process_observation(observation)
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
"""
Convert batched quaternions to axis-angle format.
Only accepts torch tensors of shape (B, 4).
Args:
quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format
Returns:
Tensor: (B, 3) axis-angle vectors
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 4)
"""
if not isinstance(quat, torch.Tensor):
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
if quat.ndim != 2 or quat.shape[1] != 4:
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
quat = quat.to(dtype=torch.float32)
device = quat.device
batch_size = quat.shape[0]
w = quat[:, 3].clamp(-1.0, 1.0)
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
result = torch.zeros((batch_size, 3), device=device)
mask = den > 1e-10
if mask.any():
angle = 2.0 * torch.acos(w[mask]) # (M,)
axis = quat[mask, :3] / den[mask].unsqueeze(1)
result[mask] = axis * angle.unsqueeze(1)
return result
+16 -2
View File
@@ -95,6 +95,7 @@ def rollout(
env: gym.vector.VectorEnv, env: gym.vector.VectorEnv,
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
seeds: list[int] | None = None, seeds: list[int] | None = None,
@@ -175,6 +176,10 @@ def rollout(
action = policy.select_action(observation) action = policy.select_action(observation)
action = postprocessor(action) action = postprocessor(action)
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
# Convert to CPU / numpy. # Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy() action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
@@ -245,6 +250,7 @@ def eval_policy(
env: gym.vector.VectorEnv, env: gym.vector.VectorEnv,
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int, n_episodes: int,
@@ -326,6 +332,7 @@ def eval_policy(
env=env, env=env,
policy=policy, policy=policy,
env_preprocessor=env_preprocessor, env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
seeds=list(seeds) if seeds else None, seeds=list(seeds) if seeds else None,
@@ -525,14 +532,15 @@ def eval_main(cfg: EvalPipelineConfig):
preprocessor_overrides=preprocessor_overrides, preprocessor_overrides=preprocessor_overrides,
) )
# Create environment-specific preprocessor (e.g., for LIBERO environments) # Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor = make_env_pre_post_processors(env_cfg=cfg.env) env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all( info = eval_policy_all(
envs=envs, envs=envs,
policy=policy, policy=policy,
env_preprocessor=env_preprocessor, env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes, n_episodes=cfg.eval.n_episodes,
@@ -574,6 +582,7 @@ def eval_one(
*, *,
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int, n_episodes: int,
@@ -590,6 +599,7 @@ def eval_one(
env=env, env=env,
policy=policy, policy=policy,
env_preprocessor=env_preprocessor, env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
n_episodes=n_episodes, n_episodes=n_episodes,
@@ -615,6 +625,7 @@ def run_one(
*, *,
policy, policy,
env_preprocessor, env_preprocessor,
env_postprocessor,
preprocessor, preprocessor,
postprocessor, postprocessor,
n_episodes: int, n_episodes: int,
@@ -638,6 +649,7 @@ def run_one(
env, env,
policy=policy, policy=policy,
env_preprocessor=env_preprocessor, env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
n_episodes=n_episodes, n_episodes=n_episodes,
@@ -656,6 +668,7 @@ def eval_policy_all(
envs: dict[str, dict[int, gym.vector.VectorEnv]], envs: dict[str, dict[int, gym.vector.VectorEnv]],
policy, policy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int, n_episodes: int,
@@ -712,6 +725,7 @@ def eval_policy_all(
run_one, run_one,
policy=policy, policy=policy,
env_preprocessor=env_preprocessor, env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
n_episodes=n_episodes, n_episodes=n_episodes,