mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8e77625cf1 | |||
| 514e3cd83d | |||
| 25ef3c520c | |||
| 464f14a0f4 | |||
| 61f8c8a0f2 | |||
| b196f04d48 | |||
| 91768a9879 | |||
| 4fb41d3e5a | |||
| b257e02ccf | |||
| 8915c6cd25 | |||
| 0ed2f87fba | |||
| a068618faf | |||
| 769eb27c87 | |||
| c3b8f65a8c | |||
| 33a8d31af0 | |||
| 6c9f169996 | |||
| 9979b62c52 | |||
| e91e48b79c | |||
| b4b5d057b1 | |||
| 9a115c303c | |||
| e5efb6b6dc | |||
| 4c67330430 |
@@ -59,6 +59,8 @@
|
||||
title: Implement your own processor
|
||||
- local: processors_robots_teleop
|
||||
title: Processors for Robots and Teleoperators
|
||||
- local: env_processor
|
||||
title: Environment Processors
|
||||
title: "Robot Processors"
|
||||
- sections:
|
||||
- local: so101
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
# Environment Processors
|
||||
|
||||
Environment processors are a critical layer in LeRobot's data processing architecture that handle **environment-specific** transformations, separate from policy-specific processing. This separation of concerns enables cleaner code, better modularity, and easier experimentation with different environments and policies.
|
||||
|
||||
## Why Environment Processors?
|
||||
|
||||
When working with different robot environments (LIBERO, MetaWorld, Aloha, etc.), each environment often has unique data formats, coordinate systems, and conventions that need standardization **before** policy processing. Without environment processors, these transformations would be:
|
||||
|
||||
1. **Hardcoded in environment code** - Making it difficult to experiment with different state representations
|
||||
2. **Duplicated across policies** - Each policy would need to handle environment-specific quirks
|
||||
3. **Mixed with policy logic** - Violating separation of concerns and making debugging harder
|
||||
|
||||
Environment processors solve this by providing a **dedicated processing layer** between raw environment observations and policy inputs.
|
||||
|
||||
## The Processing Pipeline
|
||||
|
||||
Here's how data flows through the complete processing pipeline during evaluation:
|
||||
|
||||
```python
|
||||
# In lerobot_eval.py rollout() function:
|
||||
|
||||
# 1. Raw environment observation (numpy arrays, various formats)
|
||||
raw_observation = env.step(action)
|
||||
|
||||
# 2. Convert numpy to torch, normalize images [0,1]
|
||||
observation = preprocess_observation(raw_observation)
|
||||
|
||||
# 3. Add task metadata (for multi-task environments)
|
||||
observation = add_envs_task(env, observation)
|
||||
|
||||
# 4. ENVIRONMENT-SPECIFIC preprocessing (NEW!)
|
||||
# - Flatten robot states
|
||||
# - Rotate images to match dataset conventions
|
||||
# - Handle environment-specific coordinate systems
|
||||
observation = env_preprocessor(observation)
|
||||
|
||||
# 5. POLICY-SPECIFIC preprocessing
|
||||
# - Normalize with dataset statistics
|
||||
# - Add batch dimensions
|
||||
# - Move to GPU
|
||||
# - Tokenize language instructions
|
||||
observation = preprocessor(observation)
|
||||
|
||||
# 6. Policy inference
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# 7. POLICY-SPECIFIC postprocessing
|
||||
# - Unnormalize actions
|
||||
# - Remove batch dimensions
|
||||
action = postprocessor(action)
|
||||
|
||||
# 8. ENVIRONMENT-SPECIFIC postprocessing (NEW!)
|
||||
# - Convert action formats if needed
|
||||
# - Apply environment-specific constraints
|
||||
action_transition = {"action": action}
|
||||
action_transition = env_postprocessor(action_transition)
|
||||
action = action_transition["action"]
|
||||
|
||||
# 9. Execute in environment
|
||||
env.step(action)
|
||||
```
|
||||
|
||||
## The Benefits
|
||||
|
||||
### 1. **Separation of Concerns**
|
||||
|
||||
Environment processors handle transformations specific to the **environment's data format**, while policy processors handle transformations specific to the **model's requirements**.
|
||||
|
||||
```python
|
||||
# ❌ Before: Mixed concerns
|
||||
class LiberoVLAPolicy:
|
||||
def preprocess(self, obs):
|
||||
# Environment-specific: Flatten robot state (shouldn't be in policy!)
|
||||
state = self._flatten_robot_state(obs["robot_state"])
|
||||
# Policy-specific: Normalize with dataset stats
|
||||
state = self.normalizer(state)
|
||||
return state
|
||||
|
||||
# ✅ After: Clear separation
|
||||
# Environment processor: Handles LIBERO's nested robot state
|
||||
env_preprocessor = LiberoProcessorStep() # Flattens robot_state
|
||||
|
||||
# Policy processor: Handles model requirements
|
||||
policy_preprocessor = NormalizerProcessorStep(stats=dataset_stats)
|
||||
```
|
||||
|
||||
### 2. **Flexibility and Reusability**
|
||||
|
||||
The same policy can work with different environment processors, and the same environment processor can work with different policies:
|
||||
|
||||
```python
|
||||
# Use SmolVLA policy with LIBERO environment
|
||||
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
|
||||
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
|
||||
|
||||
# Or use ACT policy with the same LIBERO environment
|
||||
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
|
||||
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
|
||||
```
|
||||
|
||||
### 3. **Easier Experimentation**
|
||||
|
||||
Want to try different state representations for LIBERO? Just create a new processor:
|
||||
|
||||
```python
|
||||
# Original: 8D state (pos + quat→axisangle + gripper)
|
||||
@ProcessorStepRegistry.register("libero_processor")
|
||||
class LiberoProcessorStep(ObservationProcessorStep):
|
||||
def _process_observation(self, obs):
|
||||
eef_pos = robot_state["eef"]["pos"] # 3D
|
||||
eef_axisangle = quat2axisangle(quat) # 3D
|
||||
gripper = robot_state["gripper"]["qpos"] # 2D
|
||||
state = torch.cat([eef_pos, eef_axisangle, gripper], dim=-1) # 8D
|
||||
return state
|
||||
|
||||
# Experiment: Add velocity for better control
|
||||
@ProcessorStepRegistry.register("libero_velocity_processor")
|
||||
class LiberoVelocityProcessorStep(ObservationProcessorStep):
|
||||
def _process_observation(self, obs):
|
||||
# Include velocities for 14D state
|
||||
eef_pos = robot_state["eef"]["pos"] # 3D
|
||||
eef_axisangle = quat2axisangle(quat) # 3D
|
||||
eef_vel = robot_state["eef"]["vel"] # 3D (NEW)
|
||||
gripper_pos = robot_state["gripper"]["qpos"] # 2D
|
||||
gripper_vel = robot_state["gripper"]["qvel"] # 3D (NEW)
|
||||
state = torch.cat([eef_pos, eef_axisangle, eef_vel,
|
||||
gripper_pos, gripper_vel], dim=-1) # 14D
|
||||
return state
|
||||
```
|
||||
|
||||
### 4. **Cleaner Environment Code**
|
||||
|
||||
Environments expose **all available data** without needing to know what downstream models will use:
|
||||
|
||||
```python
|
||||
# LIBERO environment exposes full robot state
|
||||
observation = {
|
||||
"pixels": {"image": img, "image2": img2},
|
||||
"robot_state": {
|
||||
"eef": {"pos": ..., "quat": ..., "vel": ..., "mat": ..., "axisangle": ...},
|
||||
"gripper": {"qpos": ..., "qvel": ...},
|
||||
"joints": {"pos": ..., "vel": ...}
|
||||
}
|
||||
}
|
||||
|
||||
# Environment processor decides what to use
|
||||
# Policy processor handles model-specific transformations
|
||||
```
|
||||
|
||||
## Using Environment Processors
|
||||
|
||||
### Factory Function
|
||||
|
||||
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
|
||||
|
||||
```python
|
||||
from lerobot.envs.factory import make_env_pre_post_processors
|
||||
from lerobot.envs.configs import LiberoEnv, PushtEnv
|
||||
|
||||
# For LIBERO: Returns LiberoProcessorStep in preprocessor
|
||||
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
|
||||
|
||||
# For other environments: Returns identity processors (no-op)
|
||||
pusht_cfg = PushtEnv()
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
|
||||
```
|
||||
|
||||
### Implementation in `envs/factory.py`
|
||||
|
||||
```python
|
||||
def make_env_pre_post_processors(
|
||||
env_cfg: EnvConfig,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
]:
|
||||
"""
|
||||
Create preprocessor and postprocessor pipelines for environment observations.
|
||||
|
||||
Args:
|
||||
env_cfg: The configuration of the environment.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- preprocessor: Pipeline that processes environment observations
|
||||
- postprocessor: Pipeline that processes environment outputs
|
||||
"""
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
|
||||
else:
|
||||
# For all other environments, return an identity preprocessor
|
||||
preprocessor = PolicyProcessorPipeline(steps=[])
|
||||
|
||||
# Postprocessor is currently identity for all environments
|
||||
# Future: Could add environment-specific action transformations
|
||||
postprocessor = PolicyProcessorPipeline(steps=[])
|
||||
|
||||
return preprocessor, postprocessor
|
||||
```
|
||||
|
||||
### Integration in Evaluation
|
||||
|
||||
In `lerobot_eval.py`, the environment processors are created once and used throughout:
|
||||
|
||||
```python
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
# Create environment
|
||||
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
|
||||
# Create policy
|
||||
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
|
||||
|
||||
# Create policy processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
)
|
||||
|
||||
# Create environment processors (NEW!)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||
|
||||
# Run evaluation with both processor types
|
||||
eval_policy_all(
|
||||
envs=envs,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor, # Environment-specific
|
||||
env_postprocessor=env_postprocessor, # Environment-specific
|
||||
preprocessor=preprocessor, # Policy-specific
|
||||
postprocessor=postprocessor, # Policy-specific
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
)
|
||||
```
|
||||
|
||||
## Example: LIBERO Environment Processor
|
||||
|
||||
The `LiberoProcessorStep` demonstrates a real-world environment processor:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import ObservationProcessorStep
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="libero_processor")
|
||||
class LiberoProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes LIBERO observations into the LeRobot format.
|
||||
|
||||
**State Processing:**
|
||||
- Extracts end-effector position (3D)
|
||||
- Converts quaternion to axis-angle representation (3D)
|
||||
- Extracts gripper joint positions (2D)
|
||||
- Concatenates into 8D state vector
|
||||
|
||||
**Image Processing:**
|
||||
- Rotates images 180° to match HuggingFaceVLA/libero convention
|
||||
"""
|
||||
|
||||
def _process_observation(self, observation):
|
||||
processed_obs = observation.copy()
|
||||
|
||||
# Process images: Flip 180° for camera convention
|
||||
for key in list(processed_obs.keys()):
|
||||
if key.startswith("observation.images."):
|
||||
img = processed_obs[key]
|
||||
img = torch.flip(img, dims=[2, 3]) # Flip H and W
|
||||
processed_obs[key] = img
|
||||
|
||||
# Process robot_state: Flatten to 8D vector
|
||||
if "observation.robot_state" in processed_obs:
|
||||
robot_state = processed_obs.pop("observation.robot_state")
|
||||
|
||||
eef_pos = robot_state["eef"]["pos"] # (B, 3)
|
||||
eef_quat = robot_state["eef"]["quat"] # (B, 4)
|
||||
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2)
|
||||
|
||||
# Convert quaternion to axis-angle
|
||||
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
|
||||
|
||||
# Concatenate into single state vector
|
||||
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
|
||||
state = state.float()
|
||||
|
||||
processed_obs["observation.state"] = state
|
||||
|
||||
return processed_obs
|
||||
```
|
||||
|
||||
### Why These Transformations?
|
||||
|
||||
1. **Image Rotation**: The HuggingFaceVLA/libero dataset has images rotated 180° from the raw LIBERO simulator. The processor handles this convention mismatch so policies trained on the dataset work seamlessly.
|
||||
|
||||
2. **State Flattening**: The raw LIBERO environment exposes nested dictionaries with all available state information (position, quaternion, velocity, matrix representation, etc.). The processor:
|
||||
- Selects the relevant components (pos, quat, gripper)
|
||||
- Converts quaternion to axis-angle (more suitable for learning)
|
||||
- Flattens to a single 8D vector that policies expect
|
||||
|
||||
3. **Flexibility**: The environment still exposes **all** raw data. If you want to try different state representations (e.g., including velocities, using matrix representation instead of axis-angle), you can create a new processor without modifying the environment code.
|
||||
|
||||
## Adding Environment Processors for New Environments
|
||||
|
||||
To add environment processors for a new environment:
|
||||
|
||||
### 1. Create the Processor Step
|
||||
|
||||
```python
|
||||
# In src/lerobot/processor/env_processor.py
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="myenv_processor")
|
||||
class MyEnvProcessorStep(ObservationProcessorStep):
|
||||
"""Process observations from MyEnv."""
|
||||
|
||||
def _process_observation(self, observation):
|
||||
processed = observation.copy()
|
||||
|
||||
# Your environment-specific transformations
|
||||
if "myenv.specific.state" in processed:
|
||||
state = processed.pop("myenv.specific.state")
|
||||
# Transform to standard format
|
||||
processed["observation.state"] = self._transform_state(state)
|
||||
|
||||
return processed
|
||||
```
|
||||
|
||||
### 2. Update the Factory
|
||||
|
||||
```python
|
||||
# In src/lerobot/envs/factory.py
|
||||
|
||||
def make_env_pre_post_processors(env_cfg: EnvConfig):
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
|
||||
elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
|
||||
else:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[])
|
||||
|
||||
postprocessor = PolicyProcessorPipeline(steps=[])
|
||||
return preprocessor, postprocessor
|
||||
```
|
||||
|
||||
### 3. Use in Evaluation
|
||||
|
||||
No changes needed! The evaluation script automatically uses the appropriate processor:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/my_policy \
|
||||
--env.type=myenv \ # Automatically uses MyEnvProcessorStep
|
||||
--eval.n_episodes=10
|
||||
```
|
||||
|
||||
## Future: Environment Postprocessors
|
||||
|
||||
Currently, postprocessors are identity (no-op) for all environments. Future use cases include:
|
||||
|
||||
### Action Space Transformations
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class MyEnvActionPostprocessor(ProcessorStep):
|
||||
"""Convert policy actions to environment-specific format."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition["action"]
|
||||
|
||||
# Example: Convert from Cartesian to joint space
|
||||
if self.action_space == "joint":
|
||||
action = self.ik_solver(action)
|
||||
|
||||
# Example: Apply environment-specific safety limits
|
||||
action = torch.clamp(action, self.min_action, self.max_action)
|
||||
|
||||
transition["action"] = action
|
||||
return transition
|
||||
```
|
||||
|
||||
### Coordinate System Conversions
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CoordinateTransformPostprocessor(ProcessorStep):
|
||||
"""Transform actions between coordinate systems."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition["action"]
|
||||
|
||||
# Example: Policy outputs in world frame, env expects base frame
|
||||
action = self.world_to_base_transform(action)
|
||||
|
||||
transition["action"] = action
|
||||
return transition
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Keep environment processors simple**: They should only handle environment-specific data format issues, not complex learning-related transformations.
|
||||
|
||||
2. **Use policy processors for model requirements**: Normalization, batching, device placement, and tokenization belong in policy processors.
|
||||
|
||||
3. **Expose all data from environments**: Let processors decide what to use rather than hardcoding choices in the environment.
|
||||
|
||||
4. **Document conventions**: Clearly document any coordinate system conventions, camera orientations, or data formats that your processor handles.
|
||||
|
||||
5. **Test independently**: Environment processors should be testable without loading full policies or environments.
|
||||
|
||||
## Summary
|
||||
|
||||
Environment processors provide a **clean separation** between environment-specific data transformations and policy-specific model requirements. This architecture:
|
||||
|
||||
- ✅ Enables easy experimentation with different state representations
|
||||
- ✅ Allows policies to work seamlessly across different environments
|
||||
- ✅ Keeps environment code focused on simulation/hardware interface
|
||||
- ✅ Makes processor pipelines more maintainable and debuggable
|
||||
- ✅ Follows the single responsibility principle
|
||||
|
||||
The key insight: **Environments define data formats, processors standardize them, policies consume standardized data.** Each layer has a clear, focused responsibility.
|
||||
@@ -0,0 +1,6 @@
|
||||
python ./examples/dataset/convert_hdf5_lerobot.py \
|
||||
--src-paths /fsx/jade_choghari/XVLA-Soft-Fold/0808_12am_stage_1_stage2new_new_cam_very_slow_no_sleeve \
|
||||
--output-path /fsx/jade_choghari/new-data \
|
||||
--executor local \
|
||||
--tasks-per-job 3 \
|
||||
--workers 10
|
||||
@@ -0,0 +1,437 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
# import ray
|
||||
# from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from lerobot.datasets.aggregate import (
|
||||
aggregate_data,
|
||||
aggregate_metadata,
|
||||
aggregate_stats,
|
||||
aggregate_videos,
|
||||
validate_all_metadata,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
# Cameras exported by the XVLA Soft-Fold rig; every stream is 480x640 RGB video.
_XVLA_CAMERA_KEYS = ("cam_high", "cam_left_wrist", "cam_right_wrist")

# Feature schema handed to LeRobotDataset.create(). Keys follow LeRobot's
# dotted "observation.*" naming convention; shapes/dtypes must match the
# arrays produced by load_local_episodes().
XVLA_SOFT_FOLD_FEATURES = {
    # Video features. FIX: the original literal declared "names" twice per
    # camera (["height","width","channels"] then ["height","width","rgb"]);
    # Python keeps only the last duplicate key, so ["height","width","rgb"]
    # was the effective value and is the single one kept here.
    **{
        f"observation.images.{cam}": {
            "dtype": "video",
            "shape": (480, 640, 3),
            "names": ["height", "width", "rgb"],
        }
        for cam in _XVLA_CAMERA_KEYS
    },

    "observation.states.eef_euler": {
        "dtype": "float32",
        "shape": (14,),  # presumably 7 values per arm × 2 arms — TODO confirm layout
        "names": {"values": [f"eef_euler_{i}" for i in range(14)]},
    },

    "observation.states.eef_quaternion": {
        "dtype": "float32",
        "shape": (16,),  # presumably 8 floats per arm × 2 arms — verify against source data
        "names": {"values": [f"eef_quat_{i}" for i in range(16)]},
    },

    "observation.states.eef_6d": {
        "dtype": "float32",
        "shape": (20,),  # presumably pos(3) + rot6d(6) + extras per arm — TODO confirm
        "names": {"values": [f"eef6d_{i}" for i in range(20)]},
    },

    "observation.states.eef_left_time": {
        "dtype": "float32",
        "shape": (1,),
        "names": {"values": ["eef_left_time"]},
    },

    "observation.states.eef_right_time": {
        "dtype": "float32",
        "shape": (1,),
        "names": {"values": ["eef_right_time"]},
    },

    "observation.states.qpos": {
        "dtype": "float32",
        "shape": (14,),  # 7 joints per arm × 2 arms
        "names": {"motors": [f"qpos_{i}" for i in range(14)]},
    },

    "observation.states.qvel": {
        "dtype": "float32",
        "shape": (14,),
        "names": {"motors": [f"qvel_{i}" for i in range(14)]},
    },

    "observation.states.effort": {
        "dtype": "float32",
        "shape": (14,),
        "names": {"motors": [f"effort_{i}" for i in range(14)]},
    },

    "observation.states.qpos_left_time": {
        "dtype": "float32",
        "shape": (1,),
        "names": {"values": ["qpos_left_time"]},
    },

    "observation.states.qpos_right_time": {
        "dtype": "float32",
        "shape": (1,),
        "names": {"values": ["qpos_right_time"]},
    },

    "action": {
        "dtype": "float32",
        "shape": (14,),  # 14-D joint action (7 per arm)
        "names": {"motors": [f"joint_action_{i}" for i in range(14)]},
    },

    "time_stamp": {
        "dtype": "float32",
        "shape": (1,),
        "names": {"values": ["global_timestamp"]},
    },
}
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
def decode_image(encoded_array):
    """Decode a compressed image buffer stored in HDF5 into an RGB HWC array.

    Args:
        encoded_array: 1-D uint8 array (or bytes-like) holding a JPEG/PNG
            encoded image, as stored by the recording pipeline.

    Returns:
        np.ndarray: decoded image, shape (H, W, 3), uint8, RGB channel order.

    Raises:
        ValueError: if the buffer cannot be decoded as an image.
    """
    # HDF5 gives you an array of uint8 → ensure a contiguous uint8 buffer.
    data = np.asarray(encoded_array, dtype=np.uint8)
    img = cv2.imdecode(data, cv2.IMREAD_COLOR)  # HWC BGR, or None on failure
    if img is None:
        # FIX: imdecode signals failure by returning None; without this check
        # the later cvtColor call crashes with an opaque OpenCV assertion.
        raise ValueError("cv2.imdecode failed: buffer is not a valid encoded image")
    # OpenCV decodes to BGR; the dataset expects RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from h5py import File
|
||||
|
||||
|
||||
def load_local_episodes(input_h5: Path):
    """
    Load one XVLA Soft-Fold episode from a single .hdf5 file.

    This dataset stores ONE episode per file, NOT a /data/ group, so the
    generator yields exactly one list of per-timestep frame dicts.

    Args:
        input_h5: path to the episode's .hdf5 file.

    Yields:
        list[dict]: one dict per timestep with nested "observation",
        "action", "task" and "time_stamp" fields matching
        XVLA_SOFT_FOLD_FEATURES (after image decoding downstream).
    """

    import h5py
    import numpy as np

    with h5py.File(input_h5, "r") as f:

        # Determine episode length from any observation vector
        episode_len = f["observations/eef_6d"].shape[0]

        episode = []

        for i in range(episode_len):
            frame = {
                # ----------------------
                # ROOT-LEVEL
                # ----------------------
                "task": "fold the cloth",
                # Scalar timestamps are wrapped into shape-(1,) float32 arrays
                # to match the (1,) feature shapes declared in the schema.
                "time_stamp": np.array([f["time_stamp"][i]], dtype=np.float32),

                # ----------------------
                # OBSERVATIONS
                # ----------------------
                "observation": {
                    "images": {
                        # Encoded image buffers; decoded later by decode_image().
                        "cam_high": f["observations/images/cam_high"][i],
                        "cam_left_wrist": f["observations/images/cam_left_wrist"][i],
                        "cam_right_wrist": f["observations/images/cam_right_wrist"][i],
                    },
                    "states": {
                        # NOTE(review): "eef_euler" is read from the
                        # "observations/eef" dataset — presumably the
                        # euler-angle EEF state; confirm the HDF5 key.
                        "eef_euler": f["observations/eef"][i],
                        "eef_quaternion": f["observations/eef_quaternion"][i],
                        "eef_6d": f["observations/eef_6d"][i],

                        "eef_left_time": np.array([f["observations/eef_left_time"][i]], dtype=np.float32),
                        "eef_right_time": np.array([f["observations/eef_right_time"][i]], dtype=np.float32),

                        "qpos": f["observations/qpos"][i],
                        "qvel": f["observations/qvel"][i],
                        "effort": f["observations/effort"][i],

                        "qpos_left_time": np.array([f["observations/qpos_left_time"][i]], dtype=np.float32),
                        "qpos_right_time": np.array([f["observations/qpos_right_time"][i]], dtype=np.float32),
                    },
                },

                # ----------------------
                # ACTION (your joint 14-D)
                # ----------------------
                "action": f["action"][i].astype(np.float32),
            }

            episode.append(frame)

        # One episode per file → exactly one yield per call.
        yield episode
|
||||
|
||||
# from ray.runtime_env import RuntimeEnv
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def setup_logger():
    """Return datatrove's loguru-style logger configured to emit INFO+ to stdout."""
    from sys import stdout

    from datatrove.utils.logging import logger

    # Drop any pre-existing sinks so repeated calls (one per worker) don't
    # duplicate output, then attach a single colorized stdout sink.
    logger.remove()
    logger.add(stdout, colorize=True, level="INFO")
    return logger
|
||||
|
||||
|
||||
class SaveLerobotDataset(PipelineStep):
    """Datatrove pipeline step: convert one XVLA Soft-Fold .hdf5 file into a
    temporary on-disk LeRobotDataset.

    Work is sharded by ``rank``: each rank converts exactly one entry of
    ``tasks``; the usual datatrove data stream is ignored.
    """

    name = "Save Temp LerobotDataset"
    # NOTE(review): step type says "libero2lerobot" although this converter
    # handles XVLA Soft-Fold data — possibly copied from a LIBERO script.
    type = "libero2lerobot"

    def __init__(self, tasks: list[tuple[Path, Path, str]]):
        # tasks: (input .hdf5 path, output dataset root, task instruction)
        super().__init__()
        self.tasks = tasks

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Convert the rank-th task's HDF5 file into a LeRobotDataset.

        Args:
            data: unused; present to satisfy the PipelineStep interface.
            rank: index into ``self.tasks`` selecting this worker's file.
            world_size: unused; sharding is expressed via ``rank`` only.
        """
        logger = setup_logger()

        input_h5, output_path, task_instruction = self.tasks[rank]

        # Start from a clean slate so reruns don't mix in stale episodes.
        if output_path.exists():
            shutil.rmtree(output_path)

        dataset = LeRobotDataset.create(
            repo_id=f"{input_h5.parent.name}/{input_h5.name}",
            root=output_path,
            fps=20,  # assumed capture rate of the rig — TODO confirm
            robot_type="franka",
            features=XVLA_SOFT_FOLD_FEATURES,
        )

        logger.info(f"start processing for {input_h5}, saving to {output_path}")

        raw_dataset = load_local_episodes(input_h5)
        for episode_index, episode_data in enumerate(raw_dataset):
            with self.track_time("saving episode"):

                for raw_frame in episode_data:
                    # Flatten the nested frame dict into LeRobot's dotted
                    # keys; keys must match XVLA_SOFT_FOLD_FEATURES exactly.
                    frame_data = {
                        "task": task_instruction,

                        # ---------------------- IMAGES ----------------------
                        # Stored compressed in HDF5; decode to RGB HWC arrays.
                        "observation.images.cam_high": decode_image(raw_frame["observation"]["images"]["cam_high"]),
                        "observation.images.cam_left_wrist": decode_image(raw_frame["observation"]["images"]["cam_left_wrist"]),
                        "observation.images.cam_right_wrist": decode_image(raw_frame["observation"]["images"]["cam_right_wrist"]),

                        # ---------------------- EEF STATES ----------------------
                        "observation.states.eef_euler": raw_frame["observation"]["states"]["eef_euler"],
                        "observation.states.eef_quaternion": raw_frame["observation"]["states"]["eef_quaternion"],
                        "observation.states.eef_6d": raw_frame["observation"]["states"]["eef_6d"],

                        "observation.states.eef_left_time": raw_frame["observation"]["states"]["eef_left_time"],
                        "observation.states.eef_right_time": raw_frame["observation"]["states"]["eef_right_time"],

                        # ---------------------- JOINT STATES ----------------------
                        "observation.states.qpos": raw_frame["observation"]["states"]["qpos"],
                        "observation.states.qvel": raw_frame["observation"]["states"]["qvel"],
                        "observation.states.effort": raw_frame["observation"]["states"]["effort"],

                        "observation.states.qpos_left_time": raw_frame["observation"]["states"]["qpos_left_time"],
                        "observation.states.qpos_right_time": raw_frame["observation"]["states"]["qpos_right_time"],

                        # ---------------------- ACTION ----------------------
                        "action": raw_frame["action"],

                        # ---------------------- TIME ----------------------
                        "time_stamp": raw_frame["time_stamp"],
                    }

                    dataset.add_frame(frame_data)

                dataset.save_episode()
            logger.info(f"Processed {dataset.repo_id}, episode {episode_index}, len={len(episode_data)}")
|
||||
|
||||
|
||||
def create_aggr_dataset(raw_dirs: list[Path], aggregated_dir: Path):
    """Merge the per-file temporary LeRobot datasets into one dataset.

    Validates that all sources share fps/robot_type/features, then copies
    videos, data files and metadata into ``aggregated_dir`` and writes the
    combined tasks/info/stats files.

    Args:
        raw_dirs: roots of the temporary per-episode datasets.
        aggregated_dir: destination root; deleted first if it already exists.
    """
    logger = setup_logger()

    all_metadata = [LeRobotDatasetMetadata("", root=raw_dir) for raw_dir in raw_dirs]

    # Raises (inside the helper) if sources disagree on any of these.
    fps, robot_type, features = validate_all_metadata(all_metadata)

    if aggregated_dir.exists():
        shutil.rmtree(aggregated_dir)

    aggr_meta = LeRobotDatasetMetadata.create(
        repo_id=f"{aggregated_dir.parent.name}/{aggregated_dir.name}",
        root=aggregated_dir,
        fps=fps,
        robot_type=robot_type,
        features=features,
    )

    video_keys = [key for key in features if features[key]["dtype"] == "video"]
    # Deduplicate task strings across sources and re-index them 0..N-1.
    unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
    aggr_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)

    # Running chunk/file cursors threaded through the aggregate_* helpers;
    # they track where the next copied file lands in the aggregated layout.
    meta_idx = {"chunk": 0, "file": 0}
    data_idx = {"chunk": 0, "file": 0}
    videos_idx = {key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys}

    aggr_meta.episodes = {}

    for src_meta in tqdm(all_metadata, desc="Copy data and videos"):
        # Order matters: metadata aggregation consumes the updated
        # data/video cursors from the two calls above it.
        videos_idx = aggregate_videos(
            src_meta, aggr_meta, videos_idx, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE
        )
        data_idx = aggregate_data(src_meta, aggr_meta, data_idx, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE)

        meta_idx = aggregate_metadata(src_meta, aggr_meta, meta_idx, data_idx, videos_idx)

        aggr_meta.info["total_episodes"] += src_meta.total_episodes
        aggr_meta.info["total_frames"] += src_meta.total_frames

    logger.info("write tasks")
    write_tasks(aggr_meta.tasks, aggr_meta.root)

    logger.info("write info")
    # Totals are recomputed from the sources here, overwriting the running
    # sums accumulated in the loop above (they should agree).
    aggr_meta.info.update(
        {
            "total_tasks": len(aggr_meta.tasks),
            "total_episodes": sum(m.total_episodes for m in all_metadata),
            "total_frames": sum(m.total_frames for m in all_metadata),
            "splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
        }
    )
    write_info(aggr_meta.info, aggr_meta.root)

    logger.info("write stats")
    aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
    write_stats(aggr_meta.stats, aggr_meta.root)
|
||||
|
||||
|
||||
def delete_temp_data(temp_dirs: list[Path]):
    """Remove every per-file temporary dataset directory after aggregation."""
    log = setup_logger()
    log.info("Delete temp data_dir")
    for path in temp_dirs:
        shutil.rmtree(path)
|
||||
|
||||
|
||||
def main(
|
||||
src_paths: list[Path],
|
||||
output_path: Path,
|
||||
executor: str,
|
||||
cpus_per_task: int,
|
||||
tasks_per_job: int,
|
||||
workers: int,
|
||||
resume_dir: Path = None,
|
||||
debug: bool = False,
|
||||
repo_id: str = None,
|
||||
push_to_hub: bool = False,
|
||||
):
|
||||
tasks = []
|
||||
for src_path in src_paths:
|
||||
for input_h5 in src_path.glob("*.hdf5"):
|
||||
tasks.append(
|
||||
(
|
||||
input_h5,
|
||||
(output_path / (src_path.name + "_temp") / input_h5.stem).resolve(),
|
||||
"fold the cloth", # fixed single task
|
||||
)
|
||||
)
|
||||
if len(src_paths) > 1:
|
||||
aggregate_output_path = output_path / ("_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot")
|
||||
else:
|
||||
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
|
||||
aggregate_output_path = aggregate_output_path.resolve()
|
||||
|
||||
if debug:
|
||||
executor = "local"
|
||||
workers = 1
|
||||
tasks = tasks[:2]
|
||||
push_to_hub = False
|
||||
|
||||
match executor:
|
||||
case "local":
|
||||
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
|
||||
executor = LocalPipelineExecutor
|
||||
# case "ray":
|
||||
# runtime_env = RuntimeEnv(
|
||||
# env_vars={
|
||||
# "HDF5_USE_FILE_LOCKING": "FALSE",
|
||||
# "HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
|
||||
# "SVT_LOG": "1",
|
||||
# },
|
||||
# )
|
||||
# ray.init(runtime_env=runtime_env)
|
||||
# executor = RayPipelineExecutor
|
||||
case _:
|
||||
raise ValueError(f"Executor {executor} not supported")
|
||||
|
||||
executor_config = {
|
||||
"tasks": len(tasks),
|
||||
"workers": workers,
|
||||
**({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if False else {}),
|
||||
}
|
||||
|
||||
executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_dir).run()
|
||||
create_aggr_dataset([task[1] for task in tasks], aggregate_output_path)
|
||||
delete_temp_data([task[1] for task in tasks])
|
||||
|
||||
for task in tasks:
|
||||
shutil.rmtree(task[1].parent, ignore_errors=True)
|
||||
|
||||
if push_to_hub:
|
||||
assert repo_id is not None
|
||||
tags = ["LeRobot", "libero", "franka"]
|
||||
tags.extend([src_path.name for src_path in src_paths])
|
||||
LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=aggregate_output_path,
|
||||
).push_to_hub(
|
||||
tags=tags,
|
||||
private=False,
|
||||
push_videos=True,
|
||||
license="apache-2.0",
|
||||
upload_large_folder=False,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: every flag maps 1:1 onto a `main` keyword argument.
    parser = argparse.ArgumentParser()
    parser.add_argument("--src-paths", type=Path, nargs="+", required=True)
    parser.add_argument("--output-path", type=Path, required=True)
    parser.add_argument("--executor", type=str, choices=["local", "ray"], default="local")
    parser.add_argument("--cpus-per-task", type=int, default=1)
    parser.add_argument(
        "--tasks-per-job", type=int, default=1, help="number of concurrent tasks per job, only used for ray"
    )
    parser.add_argument("--workers", type=int, default=-1, help="number of concurrent jobs to run")
    parser.add_argument("--resume-dir", type=Path, help="logs directory to resume")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--repo-id", type=str, help="required when push-to-hub is True")
    parser.add_argument("--push-to-hub", action="store_true", help="upload to hub")
    main(**vars(parser.parse_args()))
|
||||
@@ -0,0 +1,36 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from h5py import File
|
||||
|
||||
|
||||
def load_local_episodes(input_h5: Path):
    """Yield one list of per-frame dicts for each demo group in a LIBERO-style HDF5 file.

    The gripper action channel is remapped from (-1: open, 1: close) to
    (0: close, 1: open), and the flat state is the concatenation of the
    end-effector and gripper states.
    """
    with File(input_h5, "r") as h5_file:
        for demo in h5_file["data"].values():
            num_frames = len(demo["obs/agentview_rgb"])

            raw_action = np.array(demo["actions"])
            # (-1: open, 1: close) -> (0: close, 1: open)
            gripper = (1 - np.clip(raw_action[:, -1], 0, 1))[:, None]
            remapped_action = np.concatenate([raw_action[:, :6], gripper], axis=1)

            flat_state = np.concatenate(
                [
                    np.array(demo["obs/ee_states"]),
                    np.array(demo["obs/gripper_states"]),
                ],
                axis=1,
            )

            episode = {
                "observation.images.image": np.array(demo["obs/agentview_rgb"]),
                "observation.images.wrist_image": np.array(demo["obs/eye_in_hand_rgb"]),
                "observation.state": np.array(flat_state, dtype=np.float32),
                "observation.states.ee_state": np.array(demo["obs/ee_states"], dtype=np.float32),
                "observation.states.joint_state": np.array(demo["obs/joint_states"], dtype=np.float32),
                "observation.states.gripper_state": np.array(demo["obs/gripper_states"], dtype=np.float32),
                "action": np.array(remapped_action, dtype=np.float32),
            }
            yield [{k: v[i] for k, v in episode.items()} for i in range(num_frames)]
|
||||
@@ -246,7 +246,14 @@ class LiberoEnv(EnvConfig):
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"robot_state/eef/pos": f"{OBS_STATE}.eef_pos",
|
||||
"robot_state/eef/quat": f"{OBS_STATE}.eef_quat",
|
||||
"robot_state/eef/mat": f"{OBS_STATE}.eef_mat",
|
||||
"robot_state/eef/axisangle": f"{OBS_STATE}.eef_axisangle",
|
||||
"robot_state/gripper/qpos": f"{OBS_STATE}.gripper_qpos",
|
||||
"robot_state/gripper/qvel": f"{OBS_STATE}.gripper_qvel",
|
||||
"robot_state/joints/pos": f"{OBS_STATE}.joint_pos",
|
||||
"robot_state/joints/vel": f"{OBS_STATE}.joint_vel",
|
||||
"pixels/agentview_image": f"{OBS_IMAGES}.image",
|
||||
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
|
||||
}
|
||||
@@ -261,13 +268,44 @@ class LiberoEnv(EnvConfig):
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
|
||||
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
self.features["robot_state/eef/pos"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(3,),
|
||||
)
|
||||
self.features["robot_state/eef/quat"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(4,),
|
||||
)
|
||||
self.features["robot_state/eef/mat"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(3, 3),
|
||||
)
|
||||
self.features["robot_state/eef/axisangle"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(3,),
|
||||
)
|
||||
self.features["robot_state/gripper/qpos"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(2,),
|
||||
)
|
||||
self.features["robot_state/gripper/qvel"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(2,),
|
||||
)
|
||||
self.features["robot_state/joints/pos"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(7,),
|
||||
)
|
||||
self.features["robot_state/joints/vel"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(7,),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||
|
||||
|
||||
@@ -14,12 +14,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -33,6 +36,40 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
def make_env_pre_post_processors(
    env_cfg: EnvConfig,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
    """Build the environment-level observation/action processor pipelines.

    Most environments need no environment-specific processing, so both
    pipelines default to identity (no steps). LIBERO environments get a
    ``LiberoProcessorStep`` in the preprocessor to normalize their nested
    observation structure.

    Args:
        env_cfg: The configuration of the environment.

    Returns:
        A ``(preprocessor, postprocessor)`` tuple of pipelines. The
        postprocessor is currently always an identity pipeline.
    """
    is_libero = isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type
    pre_steps = [LiberoProcessorStep()] if is_libero else []
    return (
        PolicyProcessorPipeline(steps=pre_steps),
        PolicyProcessorPipeline(steps=[]),
    )
|
||||
|
||||
|
||||
def make_env(
|
||||
cfg: EnvConfig | str,
|
||||
n_envs: int = 1,
|
||||
|
||||
+74
-20
@@ -175,11 +175,39 @@ class LiberoEnv(gym.Env):
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(images),
|
||||
"agent_pos": spaces.Box(
|
||||
low=AGENT_POS_LOW,
|
||||
high=AGENT_POS_HIGH,
|
||||
shape=(OBS_STATE_DIM,),
|
||||
dtype=np.float64,
|
||||
"robot_state": spaces.Dict(
|
||||
{
|
||||
"eef": spaces.Dict(
|
||||
{
|
||||
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
|
||||
"quat": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
|
||||
),
|
||||
"mat": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(3, 3), dtype=np.float64
|
||||
),
|
||||
"axisangle": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64
|
||||
),
|
||||
}
|
||||
),
|
||||
"gripper": spaces.Dict(
|
||||
{
|
||||
"qpos": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
|
||||
),
|
||||
"qvel": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
|
||||
),
|
||||
}
|
||||
),
|
||||
"joints": spaces.Dict(
|
||||
{
|
||||
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
|
||||
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -191,6 +219,7 @@ class LiberoEnv(gym.Env):
|
||||
def render(self):
|
||||
raw_obs = self._env.env._get_observations()
|
||||
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
|
||||
image = image[::-1, ::-1] # flip both H and W for visualization
|
||||
return image
|
||||
|
||||
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
|
||||
@@ -212,23 +241,50 @@ class LiberoEnv(gym.Env):
|
||||
images = {}
|
||||
for camera_name in self.camera_name:
|
||||
image = raw_obs[camera_name]
|
||||
image = image[::-1, ::-1] # rotate 180 degrees
|
||||
images[self.camera_name_mapping[camera_name]] = image
|
||||
state = np.concatenate(
|
||||
(
|
||||
raw_obs["robot0_eef_pos"],
|
||||
quat2axisangle(raw_obs["robot0_eef_quat"]),
|
||||
raw_obs["robot0_gripper_qpos"],
|
||||
)
|
||||
)
|
||||
agent_pos = state
|
||||
|
||||
eef_pos = raw_obs.get("robot0_eef_pos")
|
||||
eef_quat = raw_obs.get("robot0_eef_quat")
|
||||
|
||||
# rotation matrix from controller
|
||||
eef_mat = self._env.robots[0].controller.ee_ori_mat if eef_pos is not None else None
|
||||
eef_axisangle = quat2axisangle(eef_quat) if eef_quat is not None else None
|
||||
gripper_qpos = raw_obs.get("robot0_gripper_qpos")
|
||||
gripper_qvel = raw_obs.get("robot0_gripper_qvel")
|
||||
joint_pos = raw_obs.get("robot0_joint_pos")
|
||||
joint_vel = raw_obs.get("robot0_joint_vel")
|
||||
obs = {
|
||||
"pixels": images,
|
||||
"robot_state": {
|
||||
"eef": {
|
||||
"pos": eef_pos, # (3,)
|
||||
"quat": eef_quat, # (4,)
|
||||
"mat": eef_mat, # (3, 3)
|
||||
"axisangle": eef_axisangle, # (3)
|
||||
},
|
||||
"gripper": {
|
||||
"qpos": gripper_qpos, # (2,)
|
||||
"qvel": gripper_qvel, # (2,)
|
||||
},
|
||||
"joints": {
|
||||
"pos": joint_pos, # (7,)
|
||||
"vel": joint_vel, # (7,)
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.obs_type == "pixels":
|
||||
return {"pixels": images.copy()}
|
||||
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
return {
|
||||
"pixels": images.copy(),
|
||||
"agent_pos": agent_pos,
|
||||
}
|
||||
# Validate required fields are present
|
||||
if eef_pos is None or eef_quat is None or gripper_qpos is None:
|
||||
raise ValueError(
|
||||
f"Missing required robot state fields in raw observation. "
|
||||
f"Got eef_pos={eef_pos is not None}, eef_quat={eef_quat is not None}, "
|
||||
f"gripper_qpos={gripper_qpos is not None}"
|
||||
)
|
||||
return obs
|
||||
|
||||
raise NotImplementedError(
|
||||
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
|
||||
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
|
||||
@@ -355,12 +411,10 @@ def create_libero_envs(
|
||||
print(f"Restricting to task_ids={task_ids_filter}")
|
||||
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for suite_name in suite_names:
|
||||
suite = _get_suite(suite_name)
|
||||
total = len(suite.tasks)
|
||||
selected = _select_task_ids(total, task_ids_filter)
|
||||
|
||||
if not selected:
|
||||
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
|
||||
|
||||
|
||||
@@ -29,10 +29,22 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import get_channel_first_image_shape
|
||||
|
||||
|
||||
def _convert_nested_dict(d):
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
result[k] = _convert_nested_dict(v)
|
||||
elif isinstance(v, np.ndarray):
|
||||
result[k] = torch.from_numpy(v)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
@@ -78,12 +90,14 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
return_observations[OBS_ENV_STATE] = env_state
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations[OBS_STATE] = agent_pos
|
||||
if "agent_pos" in observations:
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations[OBS_STATE] = agent_pos
|
||||
|
||||
if "robot_state" in observations:
|
||||
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass

import torch

from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
    """
    Processes LIBERO observations into the LeRobot format.

    This step handles the specific observation structure from LIBERO environments,
    which includes nested robot_state dictionaries and image observations.

    **State Processing:**
    - Pops the nested ``observation.robot_state`` dictionary and concatenates:
        - End-effector position (3D)
        - End-effector quaternion converted to axis-angle (3D)
        - Gripper joint positions (2D)
    - Maps the concatenated 8-D state to ``OBS_STATE`` (``"observation.state"``).

    **Image Processing:**
    - Rotates images by 180 degrees by flipping both height and width dimensions,
      accounting for the HuggingFaceVLA/libero camera orientation convention.
    """

    def _process_observation(self, observation):
        """Return a copy of *observation* with images rotated and robot state flattened."""
        processed_obs = observation.copy()
        for key in list(processed_obs.keys()):
            if key.startswith(f"{OBS_IMAGES}."):
                img = processed_obs[key]
                # Rotate 180 degrees by flipping both H and W.
                # NOTE(review): assumes channel-first batched images (B, C, H, W)
                # so dims 2 and 3 are H and W — confirm against preprocess_observation.
                img = torch.flip(img, dims=[2, 3])
                processed_obs[key] = img
        # Process robot_state into a flat state vector.
        if "observation.robot_state" in processed_obs:
            robot_state = processed_obs.pop("observation.robot_state")

            # Extract components.
            eef_pos = robot_state["eef"]["pos"]  # (B, 3)
            eef_quat = robot_state["eef"]["quat"]  # (B, 4)
            gripper_qpos = robot_state["gripper"]["qpos"]  # (B, 2)

            # Convert quaternion to axis-angle, then concatenate into one vector.
            eef_axisangle = self._quat2axisangle(eef_quat)  # (B, 3)
            state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)

            # Ensure float32 and a batch dimension.
            state = state.float()
            if state.dim() == 1:
                state = state.unsqueeze(0)

            processed_obs[OBS_STATE] = state
        return processed_obs

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Transforms feature keys from the LIBERO format to the LeRobot standard.

        Non-STATE feature groups are copied through unchanged; the STATE group
        is replaced by the single flattened 8-D ``OBS_STATE`` feature.
        """
        new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}

        # Copy over non-STATE features untouched.
        for ft, feats in features.items():
            if ft != PipelineFeatureType.STATE:
                new_features[ft] = feats.copy()

        # Rebuild STATE features with the flattened state:
        # [eef_pos(3), axis_angle(3), gripper_qpos(2)] -> shape (8,).
        # Fixed: PolicyFeature takes (type, shape) — the previous
        # key=/dtype=/description= keywords do not exist on this constructor
        # (see the PolicyFeature usages in lerobot.envs.configs).
        new_features[PipelineFeatureType.STATE] = {
            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
        }

        return new_features

    def observation(self, observation):
        """Entry point used by the pipeline: delegate to `_process_observation`."""
        return self._process_observation(observation)

    def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
        """
        Convert batched quaternions to axis-angle format.
        Only accepts torch tensors of shape (B, 4).

        Args:
            quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format

        Returns:
            Tensor: (B, 3) axis-angle vectors

        Raises:
            TypeError: if input is not a torch tensor
            ValueError: if shape is not (B, 4)
        """
        if not isinstance(quat, torch.Tensor):
            raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")

        if quat.ndim != 2 or quat.shape[1] != 4:
            raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")

        quat = quat.to(dtype=torch.float32)
        device = quat.device
        batch_size = quat.shape[0]

        # Clamp w into [-1, 1] so acos stays defined despite float error.
        w = quat[:, 3].clamp(-1.0, 1.0)

        # sin(theta/2); clamped at 0 to avoid sqrt of a tiny negative.
        den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))

        result = torch.zeros((batch_size, 3), device=device)

        # Near-identity rotations (den ~ 0) keep the zero vector.
        mask = den > 1e-10

        if mask.any():
            angle = 2.0 * torch.acos(w[mask])  # (M,)
            axis = quat[mask, :3] / den[mask].unsqueeze(1)
            result[mask] = axis * angle.unsqueeze(1)

        return result
|
||||
@@ -71,7 +71,7 @@ from tqdm import trange
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
||||
from lerobot.envs.utils import (
|
||||
add_envs_task,
|
||||
check_env_attributes_and_types,
|
||||
@@ -94,6 +94,8 @@ from lerobot.utils.utils import (
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
seeds: list[int] | None = None,
|
||||
@@ -165,11 +167,19 @@ def rollout(
|
||||
# Infer "task" from attributes of environments.
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
|
||||
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
|
||||
observation = env_preprocessor(observation)
|
||||
|
||||
observation = preprocessor(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
action = postprocessor(action)
|
||||
|
||||
action_transition = {"action": action}
|
||||
action_transition = env_postprocessor(action_transition)
|
||||
action = action_transition["action"]
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
@@ -239,6 +249,8 @@ def rollout(
|
||||
def eval_policy(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
n_episodes: int,
|
||||
@@ -319,6 +331,8 @@ def eval_policy(
|
||||
rollout_data = rollout(
|
||||
env=env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
seeds=list(seeds) if seeds else None,
|
||||
@@ -517,10 +531,16 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
|
||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
envs=envs,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
@@ -561,6 +581,8 @@ def eval_one(
|
||||
env: gym.vector.VectorEnv,
|
||||
*,
|
||||
policy: PreTrainedPolicy,
|
||||
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
n_episodes: int,
|
||||
@@ -576,6 +598,8 @@ def eval_one(
|
||||
task_result = eval_policy(
|
||||
env=env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
@@ -600,6 +624,8 @@ def run_one(
|
||||
env,
|
||||
*,
|
||||
policy,
|
||||
env_preprocessor,
|
||||
env_postprocessor,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
n_episodes: int,
|
||||
@@ -622,6 +648,8 @@ def run_one(
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
@@ -639,6 +667,8 @@ def run_one(
|
||||
def eval_policy_all(
|
||||
envs: dict[str, dict[int, gym.vector.VectorEnv]],
|
||||
policy,
|
||||
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
n_episodes: int,
|
||||
@@ -694,6 +724,8 @@ def eval_policy_all(
|
||||
task_runner = partial(
|
||||
run_one,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
|
||||
@@ -29,7 +29,7 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
||||
from lerobot.envs.utils import close_envs
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
@@ -259,6 +259,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info("Creating environment processors")
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
@@ -385,6 +387,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
eval_info = eval_policy_all(
|
||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
# Smoke-check of the LIBERO observation pipeline: raw env-style observations in,
# flattened LeRobot-format observations out.
seed = 42
np.random.seed(seed)

B = 5
obs1 = {
    "pixels": {
        "image": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
        "image2": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
    },
    "robot_state": {
        "eef": {
            "pos": np.random.randn(B, 3),
            "quat": np.random.randn(B, 4),
            "mat": np.random.randn(B, 3, 3),
            "axisangle": np.random.randn(B, 3),
        },
        "gripper": {
            "qpos": np.random.randn(B, 2),
            "qvel": np.random.randn(B, 2),
        },
        "joints": {
            "pos": np.random.randn(B, 7),
            "vel": np.random.randn(B, 7),
        },
    },
}

observation = preprocess_observation(obs1)
libero_preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
processed_obs = libero_preprocessor(observation)

# Flattened state: eef_pos(3) + axis-angle(3) + gripper qpos(2) = 8 dims.
assert "observation.state" in processed_obs
state = processed_obs["observation.state"]
assert isinstance(state, torch.Tensor)
assert state.dtype == torch.float32
assert state.shape[0] == B
assert state.shape[1] == 8

# Images come out as channel-first batched tensors.
for image_key in ("observation.images.image", "observation.images.image2"):
    assert image_key in processed_obs
    assert isinstance(processed_obs[image_key], torch.Tensor)
    assert processed_obs[image_key].shape == (B, 3, 256, 256)
|
||||
Reference in New Issue
Block a user