mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Compare commits
39 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 15960f0b5e | |||
| 8b43339563 | |||
| 5dababd21e | |||
| cbc46467b3 | |||
| e881fb6678 | |||
| acf0ba7fb3 | |||
| a74b90edd1 | |||
| 846677f9cc | |||
| af9ddcf9a2 | |||
| d602e8169c | |||
| 49baccdccb | |||
| d32006440c | |||
| f1cfdfced9 | |||
| 6a3d57031a | |||
| d74494d92b | |||
| 888a5b6249 | |||
| f247aa0701 | |||
| 1ac6a6d3fe | |||
| e698c709d8 | |||
| a988da4789 | |||
| 99963b6968 | |||
| 332ca4ccc5 | |||
| fc43246942 | |||
| 793ad86fc9 | |||
| a6dbb65917 | |||
| 6c7169c4af | |||
| f125d5e3bf | |||
| 75dcfd4886 | |||
| ff3cbaa872 | |||
| ce793cde64 | |||
| 029c4a9a76 | |||
| d893bf1e30 | |||
| 8c796b39f5 | |||
| 4ebe482a7e | |||
| 2fcc358e98 | |||
| b052843f08 | |||
| ebb464c255 | |||
| 2914ae2a96 | |||
| 645c87e3a9 |
@@ -35,6 +35,8 @@
|
||||
title: Koch v1.1
|
||||
- local: lekiwi
|
||||
title: LeKiwi
|
||||
- local: reachy2
|
||||
title: Reachy 2
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: notebooks
|
||||
|
||||
+13
-13
@@ -143,27 +143,27 @@ HIL-SERL uses a modular processor pipeline architecture that processes robot obs
|
||||
|
||||
The environment processor (`env_processor`) handles incoming observations and environment state:
|
||||
|
||||
1. **VanillaObservationProcessor**: Converts raw robot observations into standardized format
|
||||
2. **JointVelocityProcessor** (optional): Adds joint velocity information to observations
|
||||
3. **MotorCurrentProcessor** (optional): Adds motor current readings to observations
|
||||
1. **VanillaObservationProcessorStep**: Converts raw robot observations into standardized format
|
||||
2. **JointVelocityProcessorStep** (optional): Adds joint velocity information to observations
|
||||
3. **MotorCurrentProcessorStep** (optional): Adds motor current readings to observations
|
||||
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
|
||||
5. **ImageCropResizeProcessor** (optional): Crops and resizes camera images
|
||||
6. **TimeLimitProcessor** (optional): Enforces episode time limits
|
||||
7. **GripperPenaltyProcessor** (optional): Applies penalties for inappropriate gripper usage
|
||||
8. **RewardClassifierProcessor** (optional): Automated reward detection using vision models
|
||||
9. **ToBatchProcessor**: Converts data to batch format for neural network processing
|
||||
10. **DeviceProcessor**: Moves data to the specified compute device (CPU/GPU)
|
||||
5. **ImageCropResizeProcessorStep** (optional): Crops and resizes camera images
|
||||
6. **TimeLimitProcessorStep** (optional): Enforces episode time limits
|
||||
7. **GripperPenaltyProcessorStep** (optional): Applies penalties for inappropriate gripper usage
|
||||
8. **RewardClassifierProcessorStep** (optional): Automated reward detection using vision models
|
||||
9. **AddBatchDimensionProcessorStep**: Converts data to batch format for neural network processing
|
||||
10. **DeviceProcessorStep**: Moves data to the specified compute device (CPU/GPU)
|
||||
|
||||
#### Action Processor Pipeline
|
||||
|
||||
The action processor (`action_processor`) handles outgoing actions and human interventions:
|
||||
|
||||
1. **AddTeleopActionAsComplimentaryData**: Captures teleoperator actions for logging
|
||||
2. **AddTeleopEventsAsInfo**: Records intervention events and episode control signals
|
||||
1. **AddTeleopActionAsComplimentaryDataStep**: Captures teleoperator actions for logging
|
||||
2. **AddTeleopEventsAsInfoStep**: Records intervention events and episode control signals
|
||||
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
|
||||
4. **InterventionActionProcessor**: Handles human interventions and episode termination
|
||||
4. **InterventionActionProcessorStep**: Handles human interventions and episode termination
|
||||
5. **Inverse Kinematics Pipeline** (when enabled):
|
||||
- **MapDeltaActionToRobotAction**: Converts delta actions to robot action format
|
||||
- **MapDeltaActionToRobotActionStep**: Converts delta actions to robot action format
|
||||
- **EEReferenceAndDelta**: Computes end-effector reference and delta movements
|
||||
- **EEBoundsAndSafety**: Enforces workspace safety bounds
|
||||
- **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
# Reachy 2
|
||||
|
||||
Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications.
|
||||
Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform!
|
||||
|
||||
## Teleoperate Reachy 2
|
||||
|
||||
Currently, there are two ways to teleoperate Reachy 2:
|
||||
|
||||
- Pollen Robotics’ VR teleoperation (not included in LeRobot).
|
||||
- Robot-to-robot teleoperation (use one Reachy 2 to control another).
|
||||
|
||||
## Reachy 2 Simulation
|
||||
|
||||
**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core).
|
||||
|
||||
1. Install [Docker Engine](https://docs.docker.com/engine/).
|
||||
2. Run (for MuJoCo):
|
||||
|
||||
```
|
||||
docker run --rm -it \
|
||||
--name reachy \
|
||||
--privileged \
|
||||
--network host \
|
||||
--ipc host \
|
||||
--device-cgroup-rule='c 189:* rwm' \
|
||||
--group-add audio \
|
||||
-e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
|
||||
-e DISPLAY="$DISPLAY" \
|
||||
-e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
|
||||
-e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
|
||||
-v /dev:/dev \
|
||||
-v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
|
||||
-v "$HOME/.reachy.log":/home/reachy/.ros/log \
|
||||
-v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
|
||||
--entrypoint /package/launch.sh \
|
||||
pollenrobotics/reachy2_core:1.7.5.9_deploy \
|
||||
start_rviz:=true start_sdk_server:=true mujoco:=true
|
||||
```
|
||||
|
||||
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
|
||||
>
|
||||
> ```
|
||||
> docker run --rm -it \
|
||||
> --name reachy \
|
||||
> --privileged \
|
||||
> --network host \
|
||||
> --ipc host \
|
||||
> --device-cgroup-rule='c 189:* rwm' \
|
||||
> --group-add audio \
|
||||
> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \
|
||||
> -e DISPLAY="$DISPLAY" \
|
||||
> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \
|
||||
> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \
|
||||
> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \
|
||||
> -v /dev:/dev \
|
||||
> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \
|
||||
> -v "$HOME/.reachy.log":/home/reachy/.ros/log \
|
||||
> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \
|
||||
> --entrypoint /package/launch.sh \
|
||||
> pollenrobotics/reachy2_core:1.7.5.9_deploy \
|
||||
> start_rviz:=true start_sdk_server:=true mujoco:=true
|
||||
> ```
|
||||
|
||||
## Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- On your robot, check the **service images** meet the minimum versions:
|
||||
- **reachy2-core >= 1.7.5.2**
|
||||
- **webrtc >= 2.0.1.1**
|
||||
|
||||
Then, if you want to use VR teleoperation:
|
||||
|
||||
- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/).
|
||||
Use version **>=v1.2.0**
|
||||
|
||||
We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot.
|
||||
|
||||
### Install LeRobot
|
||||
|
||||
Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot.
|
||||
|
||||
Install LeRobot with Reachy 2 dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[reachy2]"
|
||||
```
|
||||
|
||||
### (Optional but recommended) Install pollen_data_acquisition_server
|
||||
|
||||
How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app.
|
||||
|
||||
> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. You’re free to develop custom clients to manage sessions to your needs.
|
||||
|
||||
In your LeRobot environment, install the server from source:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git
|
||||
cd pollen_data_acquisition_server
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server).
|
||||
|
||||
## Step 1: Recording
|
||||
|
||||
### Get Reachy 2 IP address
|
||||
|
||||
Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/).
|
||||
We strongly recommend connecting all devices (PC and robot) via **Ethernet**.
|
||||
|
||||
### Launch recording
|
||||
|
||||
There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application:
|
||||
|
||||
- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robot’s motions.
|
||||
- **Using LeRobot’s record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, it’s only for motion control.
|
||||
|
||||
### Option 1: Using Pollen data acquisition server (recommended for VR teleop)
|
||||
|
||||
Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section.
|
||||
|
||||
Launch the data acquisition server to be able to manage your session directly from the teleoperation application:
|
||||
|
||||
```bash
|
||||
python -m pollen_data_acquisition_server.server
|
||||
```
|
||||
|
||||
Then get into the teleoperation application and choose "Data acquisition session".
|
||||
You can finally setup your session by following the screens displayed.
|
||||
|
||||
> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation.
|
||||
|
||||
### Option 2: Using lerobot.record
|
||||
|
||||
Reachy 2 is fully supported by LeRobot’s recording features.
|
||||
If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app.
|
||||
|
||||
**Example: start a recording without the mobile base:**
|
||||
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.id=r2-0000 \
|
||||
--robot.use_external_commands=true \
|
||||
--robot.with_mobile_base=false \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
--teleop.ip_address=192.168.0.200 \
|
||||
--teleop.with_mobile_base=false \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.single_task="Reachy 2 recording test" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
#### Specific Options
|
||||
|
||||
**Extended setup overview (all options included):**
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=true \
|
||||
--robot.with_mobile_base=true \
|
||||
--robot.with_l_arm=true \
|
||||
--robot.with_r_arm=true \
|
||||
--robot.with_neck=true \
|
||||
--robot.with_antennas=true \
|
||||
--robot.with_left_teleop_camera=true \
|
||||
--robot.with_right_teleop_camera=true \
|
||||
--robot.with_torso_camera=false \
|
||||
--robot.disable_torque_on_disconnect=false \
|
||||
--robot.max_relative_target=5.0 \
|
||||
--teleop.type=reachy2_teleoperator \
|
||||
--teleop.ip_address=192.168.0.200 \
|
||||
--teleop.use_present_position=false \
|
||||
--teleop.with_mobile_base=false \
|
||||
--teleop.with_l_arm=true \
|
||||
--teleop.with_r_arm=true \
|
||||
--teleop.with_neck=true \
|
||||
--teleop.with_antennas=true \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.single_task="Reachy 2 recording test" \
|
||||
--dataset.num_episodes=1 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
##### `--robot.use_external_commands`
|
||||
|
||||
Determine whether LeRobot robot.send_action() sends commands to the robot.
|
||||
**Must** be set to false while using the VR teleoperation application, as the app already sends commands.
|
||||
|
||||
##### `--teleop.use_present_position`
|
||||
|
||||
Determine whether the teleoperator reads the goal or present position of the robot.
|
||||
Must be set to true if a compliant Reachy 2 is used to control another one.
|
||||
|
||||
##### Use the relevant parts
|
||||
|
||||
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
|
||||
To avoid this, you can exclude specific parts from recording and replay using:
|
||||
|
||||
````
|
||||
--robot.with_<part>=false
|
||||
```,
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
It determine whether the corresponding part is recorded in the observations. True if not set.
|
||||
|
||||
By default, **all parts are recorded**.
|
||||
|
||||
The same per-part mechanism is available in `reachy2_teleoperator` as well.
|
||||
|
||||
````
|
||||
|
||||
--teleop.with\_<part>
|
||||
|
||||
```
|
||||
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
|
||||
Determine whether the corresponding part is recorded in the actions. True if not set.
|
||||
|
||||
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
|
||||
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
|
||||
|
||||
##### Use the relevant cameras
|
||||
|
||||
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with:
|
||||
|
||||
```
|
||||
|
||||
--robot.with_left_teleop_camera=<true|false>
|
||||
--robot.with_right_teleop_camera=<true|false>
|
||||
--robot.with_torso_camera=<true|false>
|
||||
|
||||
````
|
||||
|
||||
|
||||
## Step 2: Replay
|
||||
|
||||
Make sure the robot is configured with the same parts as the dataset:
|
||||
|
||||
```bash
|
||||
python -m lerobot.replay \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--robot.use_external_commands=false \
|
||||
--robot.with_mobile_base=false \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--dataset.episode=0
|
||||
--display_data=true
|
||||
````
|
||||
|
||||
## Step 3: Train
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
--dataset.repo_id=pollen_robotics/record_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/reachy2_test \
|
||||
--job_name=reachy2 \
|
||||
--policy.device=mps \
|
||||
--wandb.enable=true \
|
||||
--policy.repo_id=pollen_robotics/record_test_policy
|
||||
```
|
||||
|
||||
## Step 4: Evaluate
|
||||
|
||||
```bash
|
||||
python -m lerobot.record \
|
||||
--robot.type=reachy2 \
|
||||
--robot.ip_address=192.168.0.200 \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=pollen_robotics/eval_record_test \
|
||||
--dataset.single_task="Evaluate reachy2 policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
@@ -16,16 +16,17 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import merge_features
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
identity_transition,
|
||||
observation_to_transition,
|
||||
transition_to_action,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
@@ -65,7 +66,7 @@ kinematics_solver = RobotKinematics(
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
@@ -74,36 +75,36 @@ robot_ee_to_joints = RobotProcessor(
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
to_transition=identity_transition,
|
||||
to_output=transition_to_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose = RobotProcessor(
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=to_transition_robot_observation,
|
||||
to_output=lambda tr: tr,
|
||||
to_transition=observation_to_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
# Build dataset action and gripper features
|
||||
action_ee_and_gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints,
|
||||
initial_features={},
|
||||
pipeline=robot_ee_to_joints_processor,
|
||||
initial_features=create_initial_features(),
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
|
||||
) # Get all ee action features + gripper pos action features
|
||||
|
||||
# Build dataset observation features
|
||||
obs_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=robot.observation_features,
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
) # Get all ee observation features
|
||||
|
||||
dataset_features = merge_features(obs_ee, action_ee_and_gripper)
|
||||
dataset_features = combine_feature_dicts(obs_ee, action_ee_and_gripper)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
@@ -147,8 +148,8 @@ for episode_idx in range(NUM_EPISODES):
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
robot_action_processor=robot_ee_to_joints,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
|
||||
@@ -17,15 +17,16 @@
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import merge_features
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
action_to_transition,
|
||||
identity_transition,
|
||||
observation_to_transition,
|
||||
transition_to_action,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
@@ -73,7 +74,7 @@ kinematics_solver = RobotKinematics(
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action
|
||||
phone_to_robot_ee_pose = RobotProcessor(
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
@@ -88,12 +89,12 @@ phone_to_robot_ee_pose = RobotProcessor(
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=lambda tr: tr,
|
||||
to_transition=action_to_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
@@ -105,31 +106,31 @@ robot_ee_to_joints = RobotProcessor(
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action,
|
||||
to_transition=identity_transition,
|
||||
to_output=transition_to_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose = RobotProcessor(
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=to_transition_robot_observation,
|
||||
to_output=lambda tr: tr,
|
||||
to_transition=observation_to_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
# Build dataset ee action features
|
||||
action_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose,
|
||||
initial_features=phone.action_features,
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
use_videos=True,
|
||||
patterns=["action.ee"],
|
||||
)
|
||||
|
||||
# Get gripper pos action features
|
||||
gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints,
|
||||
initial_features={},
|
||||
pipeline=robot_ee_to_joints_processor,
|
||||
initial_features=create_initial_features(),
|
||||
use_videos=True,
|
||||
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
|
||||
)
|
||||
@@ -137,12 +138,12 @@ gripper = aggregate_pipeline_dataset_features(
|
||||
# Build dataset ee observation features
|
||||
observation_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=robot.observation_features,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
)
|
||||
|
||||
dataset_features = merge_features(action_ee, gripper, observation_ee)
|
||||
dataset_features = combine_feature_dicts(action_ee, gripper, observation_ee)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
@@ -177,8 +178,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose,
|
||||
robot_action_processor=robot_ee_to_joints,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
@@ -193,8 +194,8 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose,
|
||||
robot_action_processor=robot_ee_to_joints,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,8 +19,8 @@ import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import action_to_transition, transition_to_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
@@ -50,7 +50,7 @@ kinematics_solver = RobotKinematics(
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints = RobotProcessor(
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
@@ -59,11 +59,11 @@ robot_ee_to_joints = RobotProcessor(
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=to_output_robot_action,
|
||||
to_transition=action_to_transition,
|
||||
to_output=transition_to_action,
|
||||
)
|
||||
|
||||
robot_ee_to_joints.reset()
|
||||
robot_ee_to_joints_processor.reset()
|
||||
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(dataset.num_frames):
|
||||
@@ -73,7 +73,7 @@ for idx in range(dataset.num_frames):
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
}
|
||||
|
||||
joint_action = robot_ee_to_joints(ee_action)
|
||||
joint_action = robot_ee_to_joints_processor(ee_action)
|
||||
action_sent = robot.send_action(joint_action)
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
import time
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import action_to_transition, transition_to_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
@@ -49,7 +49,7 @@ kinematics_solver = RobotKinematics(
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints = RobotProcessor(
|
||||
phone_to_robot_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
@@ -72,8 +72,8 @@ phone_to_robot_joints = RobotProcessor(
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=to_output_robot_action,
|
||||
to_transition=action_to_transition,
|
||||
to_output=transition_to_action,
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
@@ -85,7 +85,7 @@ while True:
|
||||
phone_obs = teleop_device.get_action()
|
||||
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints(phone_obs)
|
||||
joint_action = phone_to_robot_joints_processor(phone_obs)
|
||||
|
||||
if joint_action:
|
||||
robot.send_action(joint_action)
|
||||
|
||||
+2
-1
@@ -73,7 +73,6 @@ dependencies = [
|
||||
"pynput>=1.7.7",
|
||||
"pyserial>=3.5",
|
||||
"wandb>=0.20.0",
|
||||
"scipy>=1.15.2",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
@@ -107,6 +106,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
@@ -143,6 +143,7 @@ all = [
|
||||
"lerobot[gamepad]",
|
||||
"lerobot[hopejr]",
|
||||
"lerobot[lekiwi]",
|
||||
"lerobot[reachy2]",
|
||||
"lerobot[kinematics]",
|
||||
"lerobot[intelrealsense]",
|
||||
"lerobot[pi0]",
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_reachy2_camera import Reachy2CameraConfig
|
||||
from .reachy2_camera import Reachy2Camera
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("reachy2_camera")
|
||||
@dataclass
|
||||
class Reachy2CameraConfig(CameraConfig):
|
||||
"""Configuration class for Reachy 2 camera devices.
|
||||
|
||||
This class provides configuration options for Reachy 2 cameras,
|
||||
supporting both the teleop and depth cameras. It includes settings
|
||||
for resolution, frame rate, color mode, and the selection of the cameras.
|
||||
|
||||
Example configurations:
|
||||
```python
|
||||
# Basic configurations
|
||||
Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="left",
|
||||
ip_address="192.168.0.200", # IP address of the robot
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
) # Left teleop camera, 640x480 @ 15FPS
|
||||
```
|
||||
|
||||
Attributes:
|
||||
name: Name of the camera device. Can be "teleop" or "depth".
|
||||
image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
|
||||
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
|
||||
fps: Requested frames per second for the color stream.
|
||||
width: Requested frame width in pixels for the color stream.
|
||||
height: Requested frame height in pixels for the color stream.
|
||||
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
|
||||
ip_address: IP address of the robot. Defaults to "localhost".
|
||||
port: Port number for the camera server. Defaults to 50065.
|
||||
|
||||
Note:
|
||||
- Only 3-channel color output (RGB/BGR) is currently supported.
|
||||
"""
|
||||
|
||||
name: str
|
||||
image_type: str
|
||||
color_mode: ColorMode = ColorMode.RGB
|
||||
ip_address: str | None = "localhost"
|
||||
port: int = 50065
|
||||
# use_depth: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.name not in ["teleop", "depth"]:
|
||||
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
|
||||
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (
|
||||
self.name == "depth" and self.image_type not in ["rgb", "depth"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided."
|
||||
)
|
||||
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
@@ -0,0 +1,288 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
# Fix MSMF hardware transform compatibility for Windows before importing cv2
|
||||
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
|
||||
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
|
||||
import cv2
|
||||
import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Reachy2Camera(Camera):
|
||||
"""
|
||||
Manages Reachy 2 camera using Reachy 2 CameraManager.
|
||||
|
||||
This class provides a high-level interface to connect to, configure, and read
|
||||
frames from Reachy 2 cameras. It supports both synchronous and asynchronous
|
||||
frame reading.
|
||||
|
||||
An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image
|
||||
type (e.g., "left") to be specified in the configuration.
|
||||
|
||||
The camera's default settings (FPS, resolution, color mode) are used unless
|
||||
overridden in the configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Reachy2CameraConfig):
|
||||
"""
|
||||
Initializes the Reachy2Camera instance.
|
||||
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
self.fps = config.fps
|
||||
self.color_mode = config.color_mode
|
||||
|
||||
self.cam_manager: CameraManager | None = None
|
||||
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
self.frame_lock: Lock = Lock()
|
||||
self.latest_frame: np.ndarray | None = None
|
||||
self.new_frame_event: Event = Event()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Checks if the camera is currently connected and opened."""
|
||||
if self.config.name == "teleop":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
|
||||
elif self.config.name == "depth":
|
||||
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
|
||||
else:
|
||||
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
|
||||
|
||||
def connect(self, warmup: bool = True):
|
||||
"""
|
||||
Connects to the Reachy2 CameraManager as specified in the configuration.
|
||||
"""
|
||||
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
|
||||
self.cam_manager.initialize_cameras()
|
||||
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@staticmethod
|
||||
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Detects available Reachy 2 cameras.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains 'name', 'stereo',
|
||||
and the default profile properties (width, height, fps).
|
||||
"""
|
||||
initialized_cameras = []
|
||||
camera_manager = CameraManager(host=ip_address, port=port)
|
||||
|
||||
for camera in [camera_manager.teleop, camera_manager.depth]:
|
||||
if camera is None:
|
||||
continue
|
||||
|
||||
height, width, _, _, _, _, _ = camera.get_parameters()
|
||||
|
||||
camera_info = {
|
||||
"name": camera._cam_info.name,
|
||||
"stereo": camera._cam_info.stereo,
|
||||
"default_profile": {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"fps": 30,
|
||||
},
|
||||
}
|
||||
initialized_cameras.append(camera_info)
|
||||
|
||||
camera_manager.disconnect()
|
||||
return initialized_cameras
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
|
||||
"""
|
||||
Reads a single frame synchronously from the camera.
|
||||
|
||||
This is a blocking call.
|
||||
|
||||
Args:
|
||||
color_mode (Optional[ColorMode]): If specified, overrides the default
|
||||
color mode (`self.color_mode`) for this read operation (e.g.,
|
||||
request RGB even if default is BGR).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The captured frame as a NumPy array in the format
|
||||
(height, width, channels), using the specified or default
|
||||
color mode and applying any configured rotation.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
frame = None
|
||||
|
||||
if self.cam_manager is None:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
else:
|
||||
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
if self.config.image_type == "left":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0]
|
||||
elif self.config.image_type == "right":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0]
|
||||
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||
if self.config.image_type == "depth":
|
||||
frame = self.cam_manager.depth.get_depth_frame()[0]
|
||||
elif self.config.image_type == "rgb":
|
||||
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0]
|
||||
|
||||
if frame is None:
|
||||
return np.empty((0, 0, 3), dtype=np.uint8)
|
||||
|
||||
if self.config.color_mode == "rgb":
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
read_duration_ms = (time.perf_counter() - start_time) * 1e3
|
||||
logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
|
||||
|
||||
return frame
|
||||
|
||||
def _read_loop(self):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
|
||||
On each iteration:
|
||||
1. Reads a color frame
|
||||
2. Stores result in latest_frame (thread-safe)
|
||||
3. Sets new_frame_event to notify listeners
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
color_image = self.read()
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_frame = color_image
|
||||
self.new_frame_event.set()
|
||||
|
||||
except DeviceNotConnectedError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
|
||||
"""Starts or restarts the background read thread if it's not running."""
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=0.1)
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
self.stop_event = Event()
|
||||
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
def _stop_read_thread(self) -> None:
|
||||
"""Signals the background read thread to stop and waits for it to join."""
|
||||
if self.stop_event is not None:
|
||||
self.stop_event.set()
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
|
||||
"""
|
||||
Reads the latest available frame asynchronously.
|
||||
|
||||
This method retrieves the most recent frame captured by the background
|
||||
read thread. It does not block waiting for the camera hardware directly,
|
||||
but may wait up to timeout_ms for the background thread to provide a frame.
|
||||
|
||||
Args:
|
||||
timeout_ms (float): Maximum time in milliseconds to wait for a frame
|
||||
to become available. Defaults to 200ms (0.2 seconds).
|
||||
|
||||
Returns:
|
||||
np.ndarray: The latest captured frame as a NumPy array in the format
|
||||
(height, width, channels), processed according to configuration.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
TimeoutError: If no frame becomes available within the specified timeout.
|
||||
RuntimeError: If an unexpected error occurs.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
self._start_read_thread()
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
thread_alive = self.thread is not None and self.thread.is_alive()
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
|
||||
f"Read thread alive: {thread_alive}."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
frame = self.latest_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")
|
||||
|
||||
return frame
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Stops the background read thread (if running).
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is already disconnected.
|
||||
"""
|
||||
if not self.is_connected and self.thread is None:
|
||||
raise DeviceNotConnectedError(f"{self} not connected.")
|
||||
|
||||
if self.thread is not None:
|
||||
self._stop_read_thread()
|
||||
|
||||
if self.cam_manager is not None:
|
||||
self.cam_manager.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -37,8 +37,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
from .realsense.camera_realsense import RealSenseCamera
|
||||
|
||||
cameras[key] = RealSenseCamera(cfg)
|
||||
|
||||
elif cfg.type == "reachy2_camera":
|
||||
from .reachy2_camera.reachy2_camera import Reachy2Camera
|
||||
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
@@ -27,6 +27,11 @@ class FeatureType(str, Enum):
|
||||
LANGUAGE = "LANGUAGE"
|
||||
|
||||
|
||||
class PipelineFeatureType(str, Enum):
|
||||
ACTION = "ACTION"
|
||||
OBSERVATION = "OBSERVATION"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
MIN_MAX = "MIN_MAX"
|
||||
MEAN_STD = "MEAN_STD"
|
||||
|
||||
@@ -45,8 +45,8 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
|
||||
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor"
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor"
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
|
||||
@@ -12,84 +12,130 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.processor import DataProcessorPipeline
|
||||
|
||||
|
||||
def create_initial_features(
|
||||
action: dict[str, Any] | None, observation: dict[str, Any] | None
|
||||
) -> dict[PipelineFeatureType, dict[str, Any]]:
|
||||
"""
|
||||
Creates the initial features dict for the dataset from action and observation specs.
|
||||
|
||||
Args:
|
||||
action: A dictionary of action feature names to their types/shapes.
|
||||
observation: A dictionary of observation feature names to their types/shapes.
|
||||
|
||||
Returns:
|
||||
The initial features dictionary structured by PipelineFeatureType.
|
||||
"""
|
||||
features = {PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}}
|
||||
if action:
|
||||
features[PipelineFeatureType.ACTION] = action
|
||||
if observation:
|
||||
features[PipelineFeatureType.OBSERVATION] = observation
|
||||
return features
|
||||
|
||||
|
||||
# Helper to filter state/action keys based on regex patterns.
|
||||
def should_keep(key: str, patterns: tuple[str]) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
|
||||
|
||||
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
|
||||
for prefix in prefixes_to_strip:
|
||||
if key.startswith(prefix):
|
||||
return key[len(prefix) :]
|
||||
return key
|
||||
|
||||
|
||||
# Define prefixes to strip from feature keys for clean names.
|
||||
# Handles both fully qualified (e.g., "action.state") and short (e.g., "state") forms.
|
||||
PREFIXES_TO_STRIP = tuple(
|
||||
f"{token}." for const in (ACTION, OBS_STATE, OBS_IMAGES) for token in (const, const.split(".")[-1])
|
||||
)
|
||||
|
||||
|
||||
def aggregate_pipeline_dataset_features(
|
||||
pipeline: RobotProcessor,
|
||||
initial_features: dict[str, Any],
|
||||
pipeline: DataProcessorPipeline,
|
||||
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates the pipeline's features and returns a features dict ready for the dataset,
|
||||
filtered to only those keys matching any of the given patterns (for action/state only).
|
||||
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
||||
|
||||
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
|
||||
- `use_videos`: whether to treat image features as video streams
|
||||
- `patterns`: regexes to filter action & state features; images are included
|
||||
whenever use_videos=True, regardless of patterns.
|
||||
This function transforms initial features using the pipeline, categorizes them as action or observations
|
||||
(image or state), filters them based on `use_videos` and `patterns`, and finally
|
||||
formats them for use with a Hugging Face LeRobot Dataset.
|
||||
|
||||
Args:
|
||||
pipeline: The DataProcessorPipeline to apply.
|
||||
initial_features: A dictionary of raw feature specs for actions and observations.
|
||||
use_videos: If False, image features are excluded.
|
||||
patterns: A sequence of regex patterns to filter action and state features.
|
||||
Image features are not affected by this filter.
|
||||
|
||||
Returns:
|
||||
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Gather everything the pipeline features specifies, seeded with hardware cams:
|
||||
all_features = pipeline.transform_features(initial_features)
|
||||
|
||||
# Helper to decide which action/state keys survive the `patterns` filter:
|
||||
def keep(key: str) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
# Intermediate storage for categorized and filtered features.
|
||||
processed_features: dict[str, dict[str, Any]] = {
|
||||
"action": {},
|
||||
"observation": {},
|
||||
}
|
||||
images_token = OBS_IMAGES.split(".")[-1]
|
||||
|
||||
# Start with hardware dict, injecting initial cameras if videos are ON:
|
||||
hw: dict[str, dict[str, Any]] = {}
|
||||
if use_videos:
|
||||
cams = {
|
||||
name: shape
|
||||
for name, shape in initial_features.items()
|
||||
if isinstance(shape, tuple) and len(shape) == 3
|
||||
}
|
||||
if cams:
|
||||
hw["observation"] = dict(cams)
|
||||
|
||||
# Go over every feature from the pipeline and merge:
|
||||
for full_key, ty in all_features.items():
|
||||
if full_key.startswith(f"{ACTION}."):
|
||||
# action.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len(f"{ACTION}.") :]
|
||||
hw.setdefault(ACTION, {})[name] = ty
|
||||
|
||||
elif full_key.startswith(f"{OBS_STATE}."):
|
||||
# observation.state.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len(f"{OBS_STATE}.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
elif full_key.startswith(f"{OBS_IMAGES}."):
|
||||
# observation.images.<cam>
|
||||
# images obey ONLY the use_videos flag, not patterns
|
||||
if not use_videos:
|
||||
continue
|
||||
name = full_key[len(f"{OBS_IMAGES}.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
else:
|
||||
# anything else (e.g. policy-only features) is ignored here
|
||||
# Iterate through all features transformed by the pipeline.
|
||||
for ptype, feats in all_features.items():
|
||||
if ptype not in [PipelineFeatureType.ACTION, PipelineFeatureType.OBSERVATION]:
|
||||
continue
|
||||
|
||||
out: dict[str, dict] = {}
|
||||
if ACTION in hw:
|
||||
out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
|
||||
if "observation" in hw:
|
||||
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
|
||||
for key, value in feats.items():
|
||||
# 1. Categorize the feature.
|
||||
is_action = ptype == PipelineFeatureType.ACTION
|
||||
# Observations are classified as images if their key matches image-related tokens or if the shape of the feature is 3.
|
||||
# All other observations are treated as state.
|
||||
is_image = not is_action and (
|
||||
(isinstance(value, tuple) and len(value) == 3)
|
||||
or (
|
||||
key.startswith(f"{OBS_IMAGES}.")
|
||||
or key.startswith(f"{images_token}.")
|
||||
or f".{images_token}." in key
|
||||
)
|
||||
)
|
||||
|
||||
return out
|
||||
# 2. Apply filtering rules.
|
||||
if is_image and not use_videos:
|
||||
continue
|
||||
if not is_image and not should_keep(key, patterns):
|
||||
continue
|
||||
|
||||
# 3. Add the feature to the appropriate group with a clean name.
|
||||
name = strip_prefix(key, PREFIXES_TO_STRIP)
|
||||
if is_action:
|
||||
processed_features["action"][name] = value
|
||||
else:
|
||||
processed_features["observation"][name] = value
|
||||
|
||||
# Convert the processed features into the final dataset format.
|
||||
dataset_features = {}
|
||||
if processed_features["action"]:
|
||||
dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
|
||||
if processed_features["observation"]:
|
||||
dataset_features.update(
|
||||
hw_to_dataset_features(processed_features["observation"], "observation", use_videos)
|
||||
)
|
||||
|
||||
return dataset_features
|
||||
|
||||
+541
-45
@@ -75,13 +75,20 @@ DEFAULT_FEATURES = {
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
"""Flatten a nested dictionary by joining keys with a separator.
|
||||
|
||||
For example:
|
||||
```
|
||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
|
||||
>>> print(flatten_dict(dct))
|
||||
{"a/b": 1, "a/c/d": 2, "e": 3}
|
||||
Example:
|
||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
|
||||
>>> print(flatten_dict(dct))
|
||||
{'a/b': 1, 'a/c/d': 2, 'e': 3}
|
||||
|
||||
Args:
|
||||
d (dict): The dictionary to flatten.
|
||||
parent_key (str): The base key to prepend to the keys in this level.
|
||||
sep (str): The separator to use between keys.
|
||||
|
||||
Returns:
|
||||
dict: A flattened dictionary.
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
@@ -94,6 +101,20 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
|
||||
|
||||
def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
"""Unflatten a dictionary with delimited keys into a nested dictionary.
|
||||
|
||||
Example:
|
||||
>>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3}
|
||||
>>> print(unflatten_dict(flat_dct))
|
||||
{'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
|
||||
|
||||
Args:
|
||||
d (dict): A dictionary with flattened keys.
|
||||
sep (str): The separator used in the keys.
|
||||
|
||||
Returns:
|
||||
dict: A nested dictionary.
|
||||
"""
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
@@ -107,6 +128,16 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
|
||||
|
||||
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
||||
"""Access an item in a nested dictionary using a flattened key.
|
||||
|
||||
Args:
|
||||
obj (DictLike): The nested dictionary-like object.
|
||||
flattened_key (str): A key with parts separated by `sep`.
|
||||
sep (str): The separator used in the flattened key.
|
||||
|
||||
Returns:
|
||||
Any: The value from the nested dictionary.
|
||||
"""
|
||||
split_keys = flattened_key.split(sep)
|
||||
getter = obj[split_keys[0]]
|
||||
if len(split_keys) == 1:
|
||||
@@ -119,6 +150,19 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
"""Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible.
|
||||
|
||||
Converts torch.Tensor, np.ndarray, and np.generic types to lists or native Python types.
|
||||
|
||||
Args:
|
||||
stats (dict): A dictionary that may contain non-serializable numeric types.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with all values converted to JSON-serializable types.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If a value has an unsupported type.
|
||||
"""
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
@@ -133,6 +177,17 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
|
||||
|
||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
"""Embed image bytes into the dataset table before saving to Parquet.
|
||||
|
||||
This function prepares a Hugging Face dataset for serialization by converting
|
||||
image objects into an embedded format that can be stored in Arrow/Parquet.
|
||||
|
||||
Args:
|
||||
dataset (datasets.Dataset): The input dataset, possibly containing image features.
|
||||
|
||||
Returns:
|
||||
datasets.Dataset: The dataset with images embedded in the table storage.
|
||||
"""
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
@@ -142,38 +197,94 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
"""Load data from a JSON file.
|
||||
|
||||
Args:
|
||||
fpath (Path): Path to the JSON file.
|
||||
|
||||
Returns:
|
||||
Any: The data loaded from the JSON file.
|
||||
"""
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
"""Write data to a JSON file.
|
||||
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
data (dict): The dictionary to write.
|
||||
fpath (Path): The path to the output JSON file.
|
||||
"""
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_jsonlines(fpath: Path) -> list[Any]:
|
||||
"""Load data from a JSON Lines file.
|
||||
|
||||
Args:
|
||||
fpath (Path): Path to the JSON Lines file.
|
||||
|
||||
Returns:
|
||||
list[Any]: A list of objects loaded from the file.
|
||||
"""
|
||||
with jsonlines.open(fpath, "r") as reader:
|
||||
return list(reader)
|
||||
|
||||
|
||||
def write_jsonlines(data: dict, fpath: Path) -> None:
|
||||
"""Write a list of dictionaries to a JSON Lines file.
|
||||
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
data (dict): The list of dictionaries to write.
|
||||
fpath (Path): The path to the output JSON Lines file.
|
||||
"""
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(data)
|
||||
|
||||
|
||||
def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||
"""Append a dictionary to a JSON Lines file.
|
||||
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
data (dict): The dictionary to append.
|
||||
fpath (Path): The path to the JSON Lines file.
|
||||
"""
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "a") as writer:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path):
|
||||
"""Write dataset info metadata to its standard file path.
|
||||
|
||||
Args:
|
||||
info (dict): The dataset information dictionary.
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
"""Load dataset info metadata from its standard file path.
|
||||
|
||||
Also converts shape lists to tuples for consistency.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: The dataset information dictionary.
|
||||
"""
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
@@ -181,16 +292,40 @@ def load_info(local_dir: Path) -> dict:
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path):
|
||||
"""Serialize and write dataset statistics to their standard file path.
|
||||
|
||||
Args:
|
||||
stats (dict): The statistics dictionary (can contain tensors/numpy arrays).
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
serialized_stats = serialize_dict(stats)
|
||||
write_json(serialized_stats, local_dir / STATS_PATH)
|
||||
|
||||
|
||||
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Recursively cast numerical values in a stats dictionary to numpy arrays.
|
||||
|
||||
Args:
|
||||
stats (dict): The statistics dictionary.
|
||||
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Load dataset statistics and cast numerical values to numpy arrays.
|
||||
|
||||
Returns None if the stats file doesn't exist.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
A dictionary of statistics or None if the file is not found.
|
||||
"""
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
@@ -198,6 +333,13 @@ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
|
||||
|
||||
def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
"""Write a single task to the tasks metadata file.
|
||||
|
||||
Args:
|
||||
task_index (int): The index of the task.
|
||||
task (dict): The task description dictionary.
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
@@ -206,6 +348,16 @@ def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
"""Load tasks from the tasks metadata file.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A dictionary mapping task index to task description.
|
||||
- A dictionary mapping task description to task index.
|
||||
"""
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
@@ -213,15 +365,36 @@ def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
|
||||
|
||||
def write_episode(episode: dict, local_dir: Path):
|
||||
"""Write a single episode's metadata to the episodes metadata file.
|
||||
|
||||
Args:
|
||||
episode (dict): The episode metadata dictionary.
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
append_jsonlines(episode, local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
"""Load episode metadata from the episodes metadata file.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping episode index to episode metadata.
|
||||
"""
|
||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
|
||||
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
"""Write statistics for a single episode to the episode stats file.
|
||||
|
||||
Args:
|
||||
episode_index (int): The index of the episode.
|
||||
episode_stats (dict): The statistics for the episode.
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||
# is a dictionary of stats and not an integer.
|
||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||
@@ -229,6 +402,14 @@ def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path
|
||||
|
||||
|
||||
def load_episodes_stats(local_dir: Path) -> dict:
|
||||
"""Load per-episode statistics from the episode stats file.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping episode index to its statistics dictionary.
|
||||
"""
|
||||
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
|
||||
return {
|
||||
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
||||
@@ -239,12 +420,35 @@ def load_episodes_stats(local_dir: Path) -> dict:
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Create a per-episode stats dictionary from a global stats dictionary.
|
||||
|
||||
This is used for backward compatibility with older datasets that only had global stats.
|
||||
|
||||
Args:
|
||||
stats (dict): The global dataset statistics.
|
||||
episodes (list[int]): A list of episode indices.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping each episode index to the global stats.
|
||||
"""
|
||||
return dict.fromkeys(episodes, stats)
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
"""Load an image from a file into a numpy array.
|
||||
|
||||
Args:
|
||||
fpath (str | Path): Path to the image file.
|
||||
dtype (np.dtype): The desired data type of the output array. If floating,
|
||||
pixels are scaled to [0, 1].
|
||||
channel_first (bool): If True, converts the image to (C, H, W) format.
|
||||
Otherwise, it remains in (H, W, C) format.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The image as a numpy array.
|
||||
"""
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
@@ -255,10 +459,19 @@ def load_image_as_numpy(
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
a channel last representation (h w c) of uint8 type, to a torch image representation
|
||||
with channel first (c h w) of float32 type in range [0,1].
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
This transform function converts items from Hugging Face dataset format (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
||||
types are converted to torch.tensor.
|
||||
|
||||
Args:
|
||||
items_dict (dict): A dictionary representing a batch of data from a
|
||||
Hugging Face dataset.
|
||||
|
||||
Returns:
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
@@ -273,6 +486,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
|
||||
|
||||
def is_valid_version(version: str) -> bool:
|
||||
"""Check if a string is a valid PEP 440 version.
|
||||
|
||||
Args:
|
||||
version (str): The version string to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the version string is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
packaging.version.parse(version)
|
||||
return True
|
||||
@@ -286,6 +507,18 @@ def check_version_compatibility(
|
||||
current_version: str | packaging.version.Version,
|
||||
enforce_breaking_major: bool = True,
|
||||
) -> None:
|
||||
"""Check for version compatibility between a dataset and the current codebase.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repository ID for logging purposes.
|
||||
version_to_check (str | packaging.version.Version): The version of the dataset.
|
||||
current_version (str | packaging.version.Version): The current version of the codebase.
|
||||
enforce_breaking_major (bool): If True, raise an error on major version mismatch.
|
||||
|
||||
Raises:
|
||||
BackwardCompatibilityError: If the dataset version is from a newer, incompatible
|
||||
major version of the codebase.
|
||||
"""
|
||||
v_check = (
|
||||
packaging.version.parse(version_to_check)
|
||||
if not isinstance(version_to_check, packaging.version.Version)
|
||||
@@ -303,7 +536,14 @@ def check_version_compatibility(
|
||||
|
||||
|
||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||
"""Returns available valid versions (branches and tags) on given repo."""
|
||||
"""Return available valid versions (branches and tags) on a given Hub repo.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repository ID on the Hugging Face Hub.
|
||||
|
||||
Returns:
|
||||
list[packaging.version.Version]: A list of valid versions found.
|
||||
"""
|
||||
api = HfApi()
|
||||
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
|
||||
@@ -316,9 +556,22 @@ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||
|
||||
|
||||
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
|
||||
"""
|
||||
Returns the version if available on repo or the latest compatible one.
|
||||
Otherwise, will throw a `CompatibilityError`.
|
||||
"""Return the specified version if available on repo, or the latest compatible one.
|
||||
|
||||
If the exact version is not found, it looks for the latest version with the
|
||||
same major version number that is less than or equal to the target minor version.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repository ID on the Hugging Face Hub.
|
||||
version (str | packaging.version.Version): The target version.
|
||||
|
||||
Returns:
|
||||
str: The safe version string (e.g., "v1.2.3") to use as a revision.
|
||||
|
||||
Raises:
|
||||
RevisionNotFoundError: If the repo has no version tags.
|
||||
BackwardCompatibilityError: If only older major versions are available.
|
||||
ForwardCompatibilityError: If only newer major versions are available.
|
||||
"""
|
||||
target_version = (
|
||||
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
|
||||
@@ -360,6 +613,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""Convert a LeRobot features dictionary to a `datasets.Features` object.
|
||||
|
||||
Args:
|
||||
features (dict): A LeRobot-style feature dictionary.
|
||||
|
||||
Returns:
|
||||
datasets.Features: The corresponding Hugging Face `datasets.Features` object.
|
||||
|
||||
Raises:
|
||||
ValueError: If a feature has an unsupported shape.
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
@@ -387,6 +651,14 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
|
||||
|
||||
def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||
"""Validate that feature names do not contain invalid characters.
|
||||
|
||||
Args:
|
||||
features (dict): The LeRobot features dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If any feature name contains '/'.
|
||||
"""
|
||||
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
|
||||
if invalid_features:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
|
||||
@@ -395,6 +667,22 @@ def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||
def hw_to_dataset_features(
|
||||
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
|
||||
) -> dict[str, dict]:
|
||||
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
|
||||
|
||||
This function takes a dictionary describing hardware outputs (like joint states
|
||||
or camera image shapes) and formats it into the standard LeRobot feature
|
||||
specification.
|
||||
|
||||
Args:
|
||||
hw_features (dict): Dictionary mapping feature names to their type (float for
|
||||
joints) or shape (tuple for images).
|
||||
prefix (str): The prefix to add to the feature keys (e.g., "observation"
|
||||
or "action").
|
||||
use_video (bool): If True, image features are marked as "video", otherwise "image".
|
||||
|
||||
Returns:
|
||||
dict: A LeRobot features dictionary.
|
||||
"""
|
||||
features = {}
|
||||
joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
@@ -427,6 +715,20 @@ def hw_to_dataset_features(
|
||||
def build_dataset_frame(
|
||||
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Construct a single data frame from raw values based on dataset features.
|
||||
|
||||
A "frame" is a dictionary containing all the data for a single timestep,
|
||||
formatted as numpy arrays according to the feature specification.
|
||||
|
||||
Args:
|
||||
ds_features (dict): The LeRobot dataset features dictionary.
|
||||
values (dict): A dictionary of raw values from the hardware/environment.
|
||||
prefix (str): The prefix to filter features by (e.g., "observation"
|
||||
or "action").
|
||||
|
||||
Returns:
|
||||
dict: A dictionary representing a single frame of data.
|
||||
"""
|
||||
frame = {}
|
||||
for key, ft in ds_features.items():
|
||||
if key in DEFAULT_FEATURES or not key.startswith(prefix):
|
||||
@@ -440,6 +742,21 @@ def build_dataset_frame(
|
||||
|
||||
|
||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""Convert dataset features to policy features.
|
||||
|
||||
This function transforms the dataset's feature specification into a format
|
||||
that a policy can use, classifying features by type (e.g., visual, state,
|
||||
action) and ensuring correct shapes (e.g., channel-first for images).
|
||||
|
||||
Args:
|
||||
features (dict): The LeRobot dataset features dictionary.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If an image feature does not have a 3D shape.
|
||||
"""
|
||||
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
||||
policy_features = {}
|
||||
for key, ft in features.items():
|
||||
@@ -470,12 +787,20 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
return policy_features
|
||||
|
||||
|
||||
def merge_features(*dicts: dict) -> dict:
|
||||
"""
|
||||
Merge LeRobot grouped feature dicts.
|
||||
def combine_feature_dicts(*dicts: dict) -> dict:
|
||||
"""Merge LeRobot grouped feature dicts.
|
||||
|
||||
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
|
||||
- For others (observation.images.*), last one wins (if they are identical).
|
||||
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
|
||||
|
||||
Args:
|
||||
*dicts: A variable number of LeRobot feature dictionaries to merge.
|
||||
|
||||
Returns:
|
||||
dict: A single merged feature dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If there's a dtype mismatch for a feature being merged.
|
||||
"""
|
||||
out: dict = {}
|
||||
for d in dicts:
|
||||
@@ -521,6 +846,18 @@ def create_empty_dataset_info(
|
||||
use_videos: bool,
|
||||
robot_type: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
|
||||
Args:
|
||||
codebase_version (str): The version of the LeRobot codebase.
|
||||
fps (int): The frames per second of the data.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
"""
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
@@ -541,6 +878,18 @@ def create_empty_dataset_info(
|
||||
def get_episode_data_index(
|
||||
episode_dicts: dict[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Calculate the start and end indices for each episode in a flattened dataset.
|
||||
|
||||
Args:
|
||||
episode_dicts (dict): A dictionary mapping episode index to episode metadata,
|
||||
which must contain a "length" key.
|
||||
episodes (list[int] | None): An optional list of episode indices to consider.
|
||||
If None, all episodes are used.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with "from" and "to" keys, containing torch tensors
|
||||
with the start and end indices for each episode.
|
||||
"""
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
@@ -560,16 +909,19 @@ def check_timestamps_sync(
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
|
||||
to account for possible numerical error.
|
||||
"""Check if timestamps are separated by (1/fps) +/- tolerance.
|
||||
|
||||
This check ensures that consecutive timestamps within an episode are spaced
|
||||
correctly, accounting for possible numerical errors. It ignores the boundaries
|
||||
between episodes.
|
||||
|
||||
Args:
|
||||
timestamps (np.ndarray): Array of timestamps in seconds.
|
||||
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
|
||||
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
|
||||
episode_data_index (dict): A dictionary that includes 'to',
|
||||
which identifies indices for the end of each episode.
|
||||
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
|
||||
fps (int): Frames per second. Used to check the expected difference between
|
||||
consecutive timestamps.
|
||||
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
|
||||
raise_value_error (bool): Whether to raise a ValueError if the check fails.
|
||||
|
||||
@@ -577,7 +929,8 @@ def check_timestamps_sync(
|
||||
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If the check fails and `raise_value_error` is True.
|
||||
ValueError: If `timestamps` and `episode_indices` shapes do not match, or if
|
||||
the check fails and `raise_value_error` is True.
|
||||
"""
|
||||
if timestamps.shape != episode_indices.shape:
|
||||
raise ValueError(
|
||||
@@ -628,9 +981,23 @@ def check_timestamps_sync(
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||
actual timestamps from the dataset.
|
||||
"""Check if delta timestamps are multiples of 1/fps +/- tolerance.
|
||||
|
||||
This ensures that adding these delta timestamps to any existing timestamp in
|
||||
the dataset will result in a value that aligns with the dataset's frame rate.
|
||||
|
||||
Args:
|
||||
delta_timestamps (dict): A dictionary where values are lists of time
|
||||
deltas in seconds.
|
||||
fps (int): The frames per second of the dataset.
|
||||
tolerance_s (float): The allowed tolerance in seconds.
|
||||
raise_value_error (bool): If True, raises an error on failure.
|
||||
|
||||
Returns:
|
||||
bool: True if all deltas are valid, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
@@ -656,6 +1023,15 @@ def check_delta_timestamps(
|
||||
|
||||
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
"""Convert delta timestamps in seconds to delta indices in frames.
|
||||
|
||||
Args:
|
||||
delta_timestamps (dict): A dictionary of time deltas in seconds.
|
||||
fps (int): The frames per second of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of frame delta indices.
|
||||
"""
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
@@ -664,9 +1040,17 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
|
||||
|
||||
|
||||
def cycle(iterable):
|
||||
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
||||
"""Create a dataloader-safe cyclical iterator.
|
||||
|
||||
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
|
||||
This is an equivalent of `itertools.cycle` but is safe for use with
|
||||
PyTorch DataLoaders with multiple workers.
|
||||
See https://github.com/pytorch/pytorch/issues/23900 for details.
|
||||
|
||||
Args:
|
||||
iterable: The iterable to cycle over.
|
||||
|
||||
Yields:
|
||||
Items from the iterable, restarting from the beginning when exhausted.
|
||||
"""
|
||||
iterator = iter(iterable)
|
||||
while True:
|
||||
@@ -677,8 +1061,14 @@ def cycle(iterable):
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
|
||||
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
|
||||
exists before creating it.
|
||||
"""Create a branch on an existing Hugging Face repo.
|
||||
|
||||
Deletes the branch if it already exists before creating it.
|
||||
|
||||
Args:
|
||||
repo_id (str): The ID of the repository.
|
||||
branch (str): The name of the branch to create.
|
||||
repo_type (str | None): The type of the repository (e.g., "dataset").
|
||||
"""
|
||||
api = HfApi()
|
||||
|
||||
@@ -696,9 +1086,20 @@ def create_lerobot_dataset_card(
|
||||
dataset_info: dict | None = None,
|
||||
**kwargs,
|
||||
) -> DatasetCard:
|
||||
"""
|
||||
Keyword arguments will be used to replace values in src/lerobot/datasets/card_template.md.
|
||||
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
|
||||
"""Create a `DatasetCard` for a LeRobot dataset.
|
||||
|
||||
Keyword arguments are used to replace values in the card template.
|
||||
Note: If specified, `license` must be a valid license identifier from
|
||||
https://huggingface.co/docs/hub/repositories-licenses.
|
||||
|
||||
Args:
|
||||
tags (list | None): A list of tags to add to the dataset card.
|
||||
dataset_info (dict | None): The dataset's info dictionary, which will
|
||||
be displayed on the card.
|
||||
**kwargs: Additional keyword arguments to populate the card template.
|
||||
|
||||
Returns:
|
||||
DatasetCard: The generated dataset card object.
|
||||
"""
|
||||
card_tags = ["LeRobot"]
|
||||
|
||||
@@ -730,19 +1131,16 @@ def create_lerobot_dataset_card(
|
||||
|
||||
|
||||
class IterableNamespace(SimpleNamespace):
|
||||
"""
|
||||
A namespace object that supports both dictionary-like iteration and dot notation access.
|
||||
Automatically converts nested dictionaries into IterableNamespaces.
|
||||
"""A namespace object that supports both dictionary-like iteration and dot notation.
|
||||
|
||||
This class extends SimpleNamespace to provide:
|
||||
- Dictionary-style iteration over keys
|
||||
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
|
||||
- Dictionary-like methods: items(), keys(), values()
|
||||
- Recursive conversion of nested dictionaries
|
||||
This class extends `SimpleNamespace` to provide dictionary-style iteration,
|
||||
access to items via brackets (`obj["key"]`), and dictionary-like methods
|
||||
(`items()`, `keys()`, `values()`). Nested dictionaries are recursively
|
||||
converted to `IterableNamespace` objects.
|
||||
|
||||
Args:
|
||||
dictionary: Optional dictionary to initialize the namespace
|
||||
**kwargs: Additional keyword arguments passed to SimpleNamespace
|
||||
dictionary (dict, optional): A dictionary to initialize the namespace with.
|
||||
**kwargs: Additional keyword arguments to initialize the namespace.
|
||||
|
||||
Examples:
|
||||
>>> data = {"name": "Alice", "details": {"age": 25}}
|
||||
@@ -756,10 +1154,16 @@ class IterableNamespace(SimpleNamespace):
|
||||
>>> for key, value in ns.items():
|
||||
... print(f"{key}: {value}")
|
||||
name: Alice
|
||||
details: IterableNamespace(age=25)
|
||||
details: <__main__.IterableNamespace object at ...>
|
||||
"""
|
||||
|
||||
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
|
||||
"""Initialize the IterableNamespace.
|
||||
|
||||
Args:
|
||||
dictionary (dict, optional): Dictionary to populate the namespace.
|
||||
**kwargs: Keyword arguments to populate the namespace.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
if dictionary is not None:
|
||||
for key, value in dictionary.items():
|
||||
@@ -769,22 +1173,46 @@ class IterableNamespace(SimpleNamespace):
|
||||
setattr(self, key, value)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
"""Return an iterator over the keys of the namespace."""
|
||||
return iter(vars(self))
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Allow bracket-style access to attributes.
|
||||
|
||||
Args:
|
||||
key (str): The name of the attribute.
|
||||
|
||||
Returns:
|
||||
Any: The value of the attribute.
|
||||
"""
|
||||
return vars(self)[key]
|
||||
|
||||
def items(self):
|
||||
"""Return a view of the namespace's (key, value) pairs."""
|
||||
return vars(self).items()
|
||||
|
||||
def values(self):
|
||||
"""Return a view of the namespace's values."""
|
||||
return vars(self).values()
|
||||
|
||||
def keys(self):
|
||||
"""Return a view of the namespace's keys."""
|
||||
return vars(self).keys()
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict):
|
||||
"""Validate a single data frame against the dataset's feature specification.
|
||||
|
||||
Checks for missing/extra features, and validates the dtype and shape of each
|
||||
provided feature.
|
||||
|
||||
Args:
|
||||
frame (dict): The data frame to validate.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If the frame does not match the feature specification.
|
||||
"""
|
||||
expected_features = set(features) - set(DEFAULT_FEATURES)
|
||||
actual_features = set(frame)
|
||||
|
||||
@@ -799,6 +1227,15 @@ def validate_frame(frame: dict, features: dict):
|
||||
|
||||
|
||||
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
|
||||
"""Check for missing or extra features in a frame.
|
||||
|
||||
Args:
|
||||
actual_features (set[str]): The set of feature names present in the frame.
|
||||
expected_features (set[str]): The set of feature names expected in the frame.
|
||||
|
||||
Returns:
|
||||
str: An error message string if there's a mismatch, otherwise an empty string.
|
||||
"""
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - expected_features
|
||||
@@ -814,6 +1251,19 @@ def validate_features_presence(actual_features: set[str], expected_features: set
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
"""Validate the dtype and shape of a single feature's value.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
feature (dict): The feature specification from the LeRobot features dictionary.
|
||||
value: The value of the feature to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the feature dtype is not supported for validation.
|
||||
"""
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
@@ -829,6 +1279,17 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
|
||||
def validate_feature_numpy_array(
|
||||
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||
):
|
||||
"""Validate a feature that is expected to be a numpy array.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_dtype (str): The expected numpy dtype as a string.
|
||||
expected_shape (list[int]): The expected shape.
|
||||
value (np.ndarray): The numpy array to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
@@ -846,6 +1307,18 @@ def validate_feature_numpy_array(
|
||||
|
||||
|
||||
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
"""Validate a feature that is expected to be an image or video frame.
|
||||
|
||||
Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
@@ -862,12 +1335,35 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str):
|
||||
"""Validate a feature that is expected to be a string.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value (str): The value to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
Ensures the buffer has the required keys, contains at least one frame, and
|
||||
has features consistent with the dataset's specification.
|
||||
|
||||
Args:
|
||||
episode_buffer (dict): The buffer containing data for a single episode.
|
||||
total_episodes (int): The current total number of episodes in the dataset.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If the buffer is invalid.
|
||||
NotImplementedError: If the episode index is manually set and doesn't match.
|
||||
"""
|
||||
if "size" not in episode_buffer:
|
||||
raise ValueError("size key not found in episode_buffer")
|
||||
|
||||
|
||||
@@ -127,9 +127,29 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||
if hasattr(env.envs[0], "task_description"):
|
||||
observation["task"] = env.call("task_description")
|
||||
task_result = env.call("task_description")
|
||||
|
||||
if isinstance(task_result, tuple):
|
||||
task_result = list(task_result)
|
||||
|
||||
if not isinstance(task_result, list):
|
||||
raise TypeError(f"Expected task_description to return a list, got {type(task_result)}")
|
||||
if not all(isinstance(item, str) for item in task_result):
|
||||
raise TypeError("All items in task_description result must be strings")
|
||||
|
||||
observation["task"] = task_result
|
||||
elif hasattr(env.envs[0], "task"):
|
||||
observation["task"] = env.call("task")
|
||||
task_result = env.call("task")
|
||||
|
||||
if isinstance(task_result, tuple):
|
||||
task_result = list(task_result)
|
||||
|
||||
if not isinstance(task_result, list):
|
||||
raise TypeError(f"Expected task to return a list, got {type(task_result)}")
|
||||
if not all(isinstance(item, str) for item in task_result):
|
||||
raise TypeError("All items in task result must be strings")
|
||||
|
||||
observation["task"] = task_result
|
||||
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
|
||||
num_envs = observation[list(observation.keys())[0]].shape[0]
|
||||
observation["task"] = ["" for _ in range(num_envs)]
|
||||
|
||||
@@ -15,16 +15,16 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
@@ -33,38 +33,57 @@ def make_act_pre_post_processors(
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""Creates the pre- and post-processing pipelines for the ACT policy.
|
||||
|
||||
The pre-processing pipeline handles normalization, batching, and device placement for the model inputs.
|
||||
The post-processing pipeline handles unnormalization and moves the model outputs back to the CPU.
|
||||
|
||||
Args:
|
||||
config (ACTConfig): The ACT policy configuration object.
|
||||
dataset_stats (dict[str, dict[str, torch.Tensor]] | None): A dictionary containing dataset
|
||||
statistics (e.g., mean and std) used for normalization. Defaults to None.
|
||||
preprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the
|
||||
preprocessor pipeline's constructor. Defaults to None.
|
||||
postprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the
|
||||
postprocessor pipeline's constructor. Defaults to None.
|
||||
|
||||
Returns:
|
||||
tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: A tuple containing the
|
||||
pre-processor pipeline and the post-processor pipeline.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
device=config.device,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return (
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -16,16 +16,16 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,37 +34,63 @@ def make_diffusion_pre_post_processors(
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for a diffusion policy.
|
||||
|
||||
The pre-processing pipeline prepares the input data for the model by:
|
||||
1. Renaming features (if a `rename_map` is provided in `preprocessor_kwargs`).
|
||||
2. Normalizing the input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving the data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving the data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the diffusion policy,
|
||||
containing feature definitions, normalization mappings, and device information.
|
||||
dataset_stats: A dictionary of statistics used for normalization.
|
||||
Defaults to None.
|
||||
preprocessor_kwargs: Additional keyword arguments
|
||||
for the pre-processor pipeline. Defaults to an empty dictionary.
|
||||
postprocessor_kwargs: Additional keyword arguments
|
||||
for the post-processor pipeline. Defaults to an empty dictionary.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
@@ -38,11 +39,26 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor.pipeline import ProcessorKwargs, RobotProcessor
|
||||
from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
"""
|
||||
Retrieves a policy class by its registered name.
|
||||
|
||||
This function uses dynamic imports to avoid loading all policy classes into memory
|
||||
at once, improving startup time and reducing dependencies.
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the policy name is not recognized.
|
||||
"""
|
||||
if name == "tdmpc":
|
||||
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
@@ -84,6 +100,24 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
"""
|
||||
Instantiates a policy configuration object based on the policy type.
|
||||
|
||||
This factory function simplifies the creation of policy configuration objects by
|
||||
mapping a string identifier to the corresponding config class.
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
|
||||
"reward_classifier".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of a `PreTrainedConfig` subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `policy_type` is not recognized.
|
||||
"""
|
||||
if policy_type == "tdmpc":
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
@@ -107,7 +141,21 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
|
||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for the processor config."""
|
||||
"""
|
||||
A TypedDict defining the keyword arguments for processor configuration.
|
||||
|
||||
This provides type hints for the optional arguments passed to `make_pre_post_processors`,
|
||||
improving code clarity and enabling static analysis.
|
||||
|
||||
Attributes:
|
||||
preprocessor_config_filename: The filename for the preprocessor configuration.
|
||||
postprocessor_config_filename: The filename for the postprocessor configuration.
|
||||
preprocessor_overrides: A dictionary of overrides for the preprocessor configuration.
|
||||
postprocessor_overrides: A dictionary of overrides for the postprocessor configuration.
|
||||
dataset_stats: Dataset statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`.
|
||||
postprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`.
|
||||
"""
|
||||
|
||||
preprocessor_config_filename: str | None
|
||||
postprocessor_config_filename: str | None
|
||||
@@ -122,23 +170,28 @@ def make_pre_post_processors(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
"""Make a processor instance for a given policy type.
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Create or load pre- and post-processor pipelines for a given policy.
|
||||
|
||||
This function creates the appropriate processor configuration based on the policy type.
|
||||
Each policy type has its own processor with specific preprocessing steps.
|
||||
This function acts as a factory. It can either load existing processor pipelines
|
||||
from a pretrained path or create new ones from scratch based on the policy
|
||||
configuration. Each policy type has a dedicated factory function for its
|
||||
processors (e.g., `make_tdmpc_pre_post_processors`).
|
||||
|
||||
Args:
|
||||
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
|
||||
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
|
||||
the processor from this path instead of creating a new one.
|
||||
**kwargs: Additional keyword arguments passed to the processor creation.
|
||||
policy_cfg: The configuration of the policy for which to create processors.
|
||||
pretrained_path: An optional path to load pretrained processor pipelines from.
|
||||
If provided, pipelines are loaded from this path.
|
||||
**kwargs: Keyword arguments for processor configuration, as defined in
|
||||
`ProcessorConfigKwargs`.
|
||||
|
||||
Returns:
|
||||
Tuple of (input_processor, output_processor) for the policy.
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the policy type doesn't have a processor implemented.
|
||||
NotImplementedError: If a processor factory is not implemented for the given
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# Extract preprocessor and postprocessor kwargs
|
||||
@@ -146,16 +199,20 @@ def make_pre_post_processors(
|
||||
postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})
|
||||
|
||||
return (
|
||||
RobotProcessor.from_pretrained(
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
|
||||
config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=preprocessor_kwargs.get("to_transition"),
|
||||
to_output=preprocessor_kwargs.get("to_output"),
|
||||
),
|
||||
RobotProcessor.from_pretrained(
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
|
||||
config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=postprocessor_kwargs.get("to_transition"),
|
||||
to_output=postprocessor_kwargs.get("to_output"),
|
||||
@@ -264,25 +321,29 @@ def make_policy(
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
"""Make an instance of a policy class.
|
||||
"""
|
||||
Instantiate a policy model.
|
||||
|
||||
This function exists because (for now) we need to parse features from either a dataset or an environment
|
||||
in order to properly dimension and instantiate a policy for that dataset or environment.
|
||||
This factory function handles the logic of creating a policy, which requires
|
||||
determining the input and output feature shapes. These shapes can be derived
|
||||
either from a `LeRobotDatasetMetadata` object or an `EnvConfig` object. The function
|
||||
can either initialize a new policy from scratch or load a pretrained one.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||
be loaded with the weights from that path.
|
||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||
provided if ds_meta is not. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
|
||||
cfg: The configuration for the policy to be created. If `cfg.pretrained_path` is
|
||||
set, the policy will be loaded with weights from that path.
|
||||
ds_meta: Dataset metadata used to infer feature shapes and types. Also provides
|
||||
statistics for normalization layers.
|
||||
env_cfg: Environment configuration used to infer feature shapes and types.
|
||||
One of `ds_meta` or `env_cfg` must be provided.
|
||||
|
||||
Returns:
|
||||
PreTrainedPolicy: _description_
|
||||
An instantiated and device-placed policy model.
|
||||
|
||||
Raises:
|
||||
ValueError: If both or neither of `ds_meta` and `env_cfg` are provided.
|
||||
NotImplementedError: If attempting to use an unsupported policy-backend
|
||||
combination (e.g., VQBeT with 'mps').
|
||||
"""
|
||||
if bool(ds_meta) == bool(env_cfg):
|
||||
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
|
||||
|
||||
@@ -17,32 +17,45 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import (
|
||||
ComplementaryDataProcessor,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.rename_processor import RenameProcessor
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
|
||||
class Pi0NewLineProcessor(ComplementaryDataProcessor):
|
||||
"""Add a new line to the end of the task if it doesn't have one.
|
||||
This is required for the PaliGemma tokenizer.
|
||||
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Ensures that the task description string ends with a newline character.
|
||||
|
||||
This processing step is required for compatibility with the PaliGemma tokenizer,
|
||||
which expects a newline at the end of the text prompt. It handles both single
|
||||
strings and lists of strings for the 'task' key in complementary data.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
"""
|
||||
Adds a newline to the 'task' field if it doesn't already have one.
|
||||
|
||||
Args:
|
||||
complementary_data: A dictionary that may contain a 'task' key with a
|
||||
string or list of strings.
|
||||
|
||||
Returns:
|
||||
A new dictionary with the modified 'task' field.
|
||||
"""
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
@@ -64,13 +77,51 @@ class Pi0NewLineProcessor(ComplementaryDataProcessor):
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step does not alter the feature definitions.
|
||||
|
||||
Args:
|
||||
features: The input feature dictionary.
|
||||
|
||||
Returns:
|
||||
The unchanged feature dictionary.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_pi0_pre_post_processors(
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0 policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
@@ -78,39 +129,39 @@ def make_pi0_pre_post_processors(
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
|
||||
TokenizerProcessor(
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessor(device=config.device),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return (
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -16,55 +16,77 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_pi0fast_pre_post_processors(
|
||||
config: PI0Config,
|
||||
config: PI0FASTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0Fast policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0Fast policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessor(
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -17,16 +17,16 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,37 +35,59 @@ def make_sac_pre_post_processors(
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the SAC policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the SAC policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -17,11 +17,11 @@ import torch
|
||||
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
IdentityProcessor,
|
||||
NormalizerProcessor,
|
||||
DeviceProcessorStep,
|
||||
IdentityProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RobotProcessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -30,30 +30,50 @@ def make_classifier_processor(
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the reward classifier.
|
||||
|
||||
The pre-processing pipeline prepares input data for the classifier by:
|
||||
1. Normalizing both input and output features based on dataset statistics.
|
||||
2. Moving the data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the classifier's output by:
|
||||
1. Moving the data to the CPU.
|
||||
2. Applying an identity step, as no unnormalization is needed for the output logits.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the RewardClassifier.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
NormalizerProcessor(
|
||||
NormalizerProcessorStep(
|
||||
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
NormalizerProcessor(
|
||||
NormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessor(device=config.device),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
|
||||
output_steps = [DeviceProcessorStep(device="cpu"), IdentityProcessorStep()]
|
||||
|
||||
return (
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name="classifier_preprocessor",
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name="classifier_postprocessor",
|
||||
**postprocessor_kwargs,
|
||||
|
||||
@@ -16,21 +16,20 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
TokenizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
)
|
||||
from lerobot.processor.pipeline import (
|
||||
ComplementaryDataProcessor,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
@@ -39,52 +38,82 @@ def make_smolvla_pre_post_processors(
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the SmolVLA policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Ensuring the language task description ends with a newline character.
|
||||
5. Tokenizing the language task description.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output actions to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the SmolVLA policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessor(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
SmolVLANewLineProcessor(),
|
||||
TokenizerProcessor(
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.vlm_model_name,
|
||||
padding=config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=config.tokenizer_max_length,
|
||||
),
|
||||
DeviceProcessor(device=config.device),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
|
||||
class SmolVLANewLineProcessor(ComplementaryDataProcessor):
|
||||
"""Add a new line to the end of the task if it doesn't have one."""
|
||||
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
A processor step that ensures the 'task' description ends with a newline character.
|
||||
|
||||
This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a
|
||||
newline at the end of the prompt. It handles both single string tasks and lists
|
||||
of string tasks.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
@@ -107,3 +136,8 @@ class SmolVLANewLineProcessor(ComplementaryDataProcessor):
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -16,16 +16,16 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,37 +34,59 @@ def make_tdmpc_pre_post_processors(
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the TDMPC policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the TDMPC policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -17,16 +17,16 @@
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,37 +35,59 @@ def make_vqbet_pre_post_processors(
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the VQ-BeT policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features, allowing customization to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the VQ-BeT policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessor(rename_map={}), # Let the possibility to the user to rename the keys
|
||||
NormalizerProcessor(
|
||||
RenameObservationsProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
RobotProcessor(
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -14,74 +14,90 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .batch_processor import ToBatchProcessor
|
||||
from .delta_action_processor import MapDeltaActionToRobotAction, MapTensorToDeltaActionDict
|
||||
from .device_processor import DeviceProcessor
|
||||
from .gym_action_processor import Numpy2TorchActionProcessor, Torch2NumpyActionProcessor
|
||||
from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryData,
|
||||
AddTeleopEventsAsInfo,
|
||||
GripperPenaltyProcessor,
|
||||
ImageCropResizeProcessor,
|
||||
InterventionActionProcessor,
|
||||
RewardClassifierProcessor,
|
||||
TimeLimitProcessor,
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
merge_transitions,
|
||||
transition_to_batch,
|
||||
transition_to_dataset_frame,
|
||||
)
|
||||
from .joint_observations_processor import JointVelocityProcessor, MotorCurrentProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessor
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .gym_action_processor import Numpy2TorchActionProcessorStep, Torch2NumpyActionProcessorStep
|
||||
from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
TimeLimitProcessorStep,
|
||||
)
|
||||
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessorStep
|
||||
from .pipeline import (
|
||||
ActionProcessor,
|
||||
DoneProcessor,
|
||||
EnvTransition,
|
||||
IdentityProcessor,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ActionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DoneProcessorStep,
|
||||
IdentityProcessorStep,
|
||||
InfoProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RewardProcessor,
|
||||
RobotProcessor,
|
||||
TransitionKey,
|
||||
TruncatedProcessor,
|
||||
RewardProcessorStep,
|
||||
RobotProcessorPipeline,
|
||||
TruncatedProcessorStep,
|
||||
)
|
||||
from .rename_processor import RenameProcessor
|
||||
from .tokenizer_processor import TokenizerProcessor
|
||||
from .rename_processor import RenameObservationsProcessorStep
|
||||
from .tokenizer_processor import TokenizerProcessorStep
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessor",
|
||||
"AddTeleopActionAsComplimentaryData",
|
||||
"AddTeleopEventsAsInfo",
|
||||
"DeviceProcessor",
|
||||
"DoneProcessor",
|
||||
"MapDeltaActionToRobotAction",
|
||||
"MapTensorToDeltaActionDict",
|
||||
"ActionProcessorStep",
|
||||
"AddTeleopActionAsComplimentaryDataStep",
|
||||
"AddTeleopEventsAsInfoStep",
|
||||
"ComplementaryDataProcessorStep",
|
||||
"batch_to_transition",
|
||||
"create_transition",
|
||||
"DeviceProcessorStep",
|
||||
"DoneProcessorStep",
|
||||
"EnvTransition",
|
||||
"GripperPenaltyProcessor",
|
||||
"IdentityProcessor",
|
||||
"ImageCropResizeProcessor",
|
||||
"InfoProcessor",
|
||||
"InterventionActionProcessor",
|
||||
"JointVelocityProcessor",
|
||||
"MapDeltaActionToRobotAction",
|
||||
"MotorCurrentProcessor",
|
||||
"NormalizerProcessor",
|
||||
"UnnormalizerProcessor",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"ObservationProcessor",
|
||||
"IdentityProcessorStep",
|
||||
"ImageCropResizeProcessorStep",
|
||||
"InfoProcessorStep",
|
||||
"InterventionActionProcessorStep",
|
||||
"JointVelocityProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"merge_transitions",
|
||||
"MotorCurrentProcessorStep",
|
||||
"NormalizerProcessorStep",
|
||||
"Numpy2TorchActionProcessorStep",
|
||||
"ObservationProcessorStep",
|
||||
"PolicyProcessorPipeline",
|
||||
"ProcessorKwargs",
|
||||
"ProcessorStep",
|
||||
"ProcessorStepRegistry",
|
||||
"RenameProcessor",
|
||||
"RewardClassifierProcessor",
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"ToBatchProcessor",
|
||||
"TokenizerProcessor",
|
||||
"TimeLimitProcessor",
|
||||
"Numpy2TorchActionProcessor",
|
||||
"Torch2NumpyActionProcessor",
|
||||
"RenameObservationsProcessorStep",
|
||||
"RewardClassifierProcessorStep",
|
||||
"RewardProcessorStep",
|
||||
"DataProcessorPipeline",
|
||||
"TimeLimitProcessorStep",
|
||||
"AddBatchDimensionProcessorStep",
|
||||
"RobotProcessorPipeline",
|
||||
"TokenizerProcessorStep",
|
||||
"Torch2NumpyActionProcessorStep",
|
||||
"transition_to_batch",
|
||||
"transition_to_dataset_frame",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
"TruncatedProcessorStep",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -11,16 +13,25 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script defines processor steps for adding a batch dimension to various components of an environment transition.
|
||||
|
||||
These steps are designed to process actions, observations, and complementary data, making them suitable for batch processing by adding a leading dimension. This is a common requirement before feeding data into a neural network model.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
|
||||
from .core import EnvTransition
|
||||
from .pipeline import (
|
||||
ActionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
@@ -28,22 +39,66 @@ from lerobot.processor.pipeline import (
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_action")
|
||||
class ToBatchProcessorAction(ActionProcessor):
|
||||
"""Process action component in-place, adding batch dimension if needed."""
|
||||
class AddBatchDimensionActionStep(ActionProcessorStep):
|
||||
"""
|
||||
Processor step to add a batch dimension to a 1D tensor action.
|
||||
|
||||
def action(self, action):
|
||||
This is useful for creating a batch of size 1 from a single action sample.
|
||||
"""
|
||||
|
||||
def action(self, action: Tensor) -> Tensor:
|
||||
"""
|
||||
Adds a batch dimension to the action if it's a 1D tensor.
|
||||
|
||||
Args:
|
||||
action: The action tensor.
|
||||
|
||||
Returns:
|
||||
The action tensor with an added batch dimension.
|
||||
"""
|
||||
if not isinstance(action, Tensor) or action.dim() != 1:
|
||||
return action
|
||||
|
||||
return action.unsqueeze(0)
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Adding a batch dimension does not alter the feature definition.
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
|
||||
class ToBatchProcessorObservation(ObservationProcessor):
|
||||
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||
class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processor step to add a batch dimension to observations.
|
||||
|
||||
def observation(self, observation):
|
||||
It handles different types of observations:
|
||||
- State vectors (1D tensors).
|
||||
- Single images (3D tensors).
|
||||
- Dictionaries of multiple images (3D tensors).
|
||||
"""
|
||||
|
||||
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Adds a batch dimension to tensor-based observations in the observation dictionary.
|
||||
|
||||
Args:
|
||||
observation: The observation dictionary.
|
||||
|
||||
Returns:
|
||||
The observation dictionary with batch dimensions added to tensors.
|
||||
"""
|
||||
# Process state observations - add batch dim if 1D
|
||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||
if state_key in observation:
|
||||
@@ -63,13 +118,44 @@ class ToBatchProcessorObservation(ObservationProcessor):
|
||||
observation[key] = value.unsqueeze(0)
|
||||
return observation
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Adding a batch dimension does not alter the feature definition.
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
|
||||
class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
|
||||
"""Process complementary data in-place, handling task field batching."""
|
||||
class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Processor step to add a batch dimension to complementary data fields.
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
Handles specific keys like 'task', 'index', and 'task_index' to make them batched.
|
||||
- 'task' (str) is wrapped in a list.
|
||||
- 'index' and 'task_index' (0D tensors) get a batch dimension.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Adds a batch dimension to specific fields in the complementary data dictionary.
|
||||
|
||||
Args:
|
||||
complementary_data: The complementary data dictionary.
|
||||
|
||||
Returns:
|
||||
The complementary data dictionary with batch dimensions added.
|
||||
"""
|
||||
# Process task field - wrap string in list to add batch dimension
|
||||
if "task" in complementary_data:
|
||||
task_value = complementary_data["task"]
|
||||
@@ -89,54 +175,76 @@ class ToBatchProcessorComplementaryData(ComplementaryDataProcessor):
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Adding a batch dimension does not alter the feature definition.
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||
class ToBatchProcessor(ProcessorStep):
|
||||
"""Processor that adds batch dimensions to observations and actions when needed.
|
||||
class AddBatchDimensionProcessorStep(ProcessorStep):
|
||||
"""
|
||||
A composite processor step that adds a batch dimension to the entire environment transition.
|
||||
|
||||
This processor ensures that observations and actions have proper batch dimensions for model processing:
|
||||
This step combines individual processors for actions, observations, and complementary data
|
||||
to create a batched transition (batch size 1) from a single-instance transition.
|
||||
|
||||
- For state observations (observation.state, observation.environment_state):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
|
||||
- For image observations (observation.image, observation.images.*):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
|
||||
|
||||
- For actions:
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
|
||||
- For task field in complementary data:
|
||||
Wraps string task in a list to add batch dimension
|
||||
(task must be a string or list of strings)
|
||||
|
||||
This is useful when processing single transitions that need to be batched for
|
||||
model inference or when converting from unbatched environment outputs to
|
||||
batched model inputs.
|
||||
|
||||
The processor only modifies tensors that need batching and leaves already
|
||||
batched tensors unchanged.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# State: (7,) -> (1, 7)
|
||||
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
||||
# Action: (4,) -> (1, 4)
|
||||
# Task: "pick_cube" -> ["pick_cube"]
|
||||
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
||||
```
|
||||
Attributes:
|
||||
to_batch_action_processor: Processor for the action component.
|
||||
to_batch_observation_processor: Processor for the observation component.
|
||||
to_batch_complementary_data_processor: Processor for the complementary data component.
|
||||
"""
|
||||
|
||||
to_batch_action_processor: ToBatchProcessorAction = field(default_factory=ToBatchProcessorAction)
|
||||
to_batch_observation_processor: ToBatchProcessorObservation = field(
|
||||
default_factory=ToBatchProcessorObservation
|
||||
to_batch_action_processor: AddBatchDimensionActionStep = field(
|
||||
default_factory=AddBatchDimensionActionStep
|
||||
)
|
||||
to_batch_complementary_data_processor: ToBatchProcessorComplementaryData = field(
|
||||
default_factory=ToBatchProcessorComplementaryData
|
||||
to_batch_observation_processor: AddBatchDimensionObservationStep = field(
|
||||
default_factory=AddBatchDimensionObservationStep
|
||||
)
|
||||
to_batch_complementary_data_processor: AddBatchDimensionComplementaryDataStep = field(
|
||||
default_factory=AddBatchDimensionComplementaryDataStep
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Applies the batching process to all relevant parts of an environment transition.
|
||||
|
||||
Args:
|
||||
transition: The environment transition to process.
|
||||
|
||||
Returns:
|
||||
The environment transition with a batch dimension added.
|
||||
"""
|
||||
transition = self.to_batch_action_processor(transition)
|
||||
transition = self.to_batch_observation_processor(transition)
|
||||
transition = self.to_batch_complementary_data_processor(transition)
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Adding a batch dimension does not alter the feature definition.
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
# NOTE: We ignore the batch dimension when transforming features
|
||||
return features
|
||||
|
||||
+331
-134
@@ -16,18 +16,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
||||
|
||||
from .pipeline import EnvTransition, TransitionKey
|
||||
from .core import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
@singledispatch
|
||||
@@ -44,12 +43,12 @@ def to_tensor(
|
||||
different input types appropriately.
|
||||
|
||||
Args:
|
||||
value: Input value to convert (tensor, array, scalar, sequence, etc.)
|
||||
value: Input value to convert (tensor, array, scalar, sequence, etc.).
|
||||
dtype: Target tensor dtype. If None, preserves original dtype.
|
||||
device: Target device for the tensor.
|
||||
|
||||
Returns:
|
||||
PyTorch tensor.
|
||||
A PyTorch tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input type is not supported.
|
||||
@@ -59,7 +58,7 @@ def to_tensor(
|
||||
|
||||
@to_tensor.register(torch.Tensor)
|
||||
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle existing PyTorch tensors."""
|
||||
"""Handle conversion for existing PyTorch tensors."""
|
||||
if dtype is not None:
|
||||
value = value.to(dtype=dtype)
|
||||
if device is not None:
|
||||
@@ -75,17 +74,17 @@ def _(
|
||||
device=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Handle numpy arrays."""
|
||||
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
|
||||
"""Handle conversion for numpy arrays."""
|
||||
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars.
|
||||
if value.ndim == 0:
|
||||
# Numpy scalars should be converted to 0-dimensional tensors
|
||||
# Numpy scalars should be converted to 0-dimensional tensors.
|
||||
scalar_value = value.item()
|
||||
return torch.tensor(scalar_value, dtype=dtype, device=device)
|
||||
|
||||
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
|
||||
# Create tensor from numpy array.
|
||||
tensor = torch.from_numpy(value)
|
||||
|
||||
# Apply dtype conversion if specified
|
||||
# Apply dtype and device conversion if specified.
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
if device is not None:
|
||||
@@ -99,20 +98,20 @@ def _(
|
||||
@to_tensor.register(np.integer)
|
||||
@to_tensor.register(np.floating)
|
||||
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle scalar values including numpy scalars."""
|
||||
"""Handle conversion for scalar values including numpy scalars."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(list)
|
||||
@to_tensor.register(tuple)
|
||||
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle sequences (lists, tuples)."""
|
||||
"""Handle conversion for sequences (lists, tuples)."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(dict)
|
||||
def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
"""Handle dictionaries by recursively converting values to tensors."""
|
||||
"""Handle conversion for dictionaries by recursively converting their values to tensors."""
|
||||
if not value:
|
||||
return {}
|
||||
|
||||
@@ -122,7 +121,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
continue
|
||||
|
||||
if isinstance(sub_value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
# Recursively process nested dictionaries.
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
@@ -130,7 +129,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert individual values to tensors
|
||||
# Convert individual values to tensors.
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
@@ -139,17 +138,46 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
return result
|
||||
|
||||
|
||||
def _from_tensor(x: Any):
|
||||
def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
|
||||
"""
|
||||
Convert a PyTorch tensor to a numpy array or scalar if applicable.
|
||||
|
||||
If the input is not a tensor, it is returned unchanged.
|
||||
|
||||
Args:
|
||||
x: The input, which can be a tensor or any other type.
|
||||
|
||||
Returns:
|
||||
A numpy array, a scalar, or the original input.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
|
||||
return x
|
||||
|
||||
|
||||
def _is_image(arr: Any) -> bool:
|
||||
"""
|
||||
Check if a given array is likely an image (uint8, 3D).
|
||||
|
||||
Args:
|
||||
arr: The array to check.
|
||||
|
||||
Returns:
|
||||
True if the array matches the image criteria, False otherwise.
|
||||
"""
|
||||
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
|
||||
|
||||
|
||||
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""
|
||||
Separate an observation dictionary into state and image components.
|
||||
|
||||
Args:
|
||||
obs: The observation dictionary.
|
||||
|
||||
Returns:
|
||||
A tuple containing two dictionaries: one for state and one for images.
|
||||
"""
|
||||
state, images = {}, {}
|
||||
for k, v in obs.items():
|
||||
if "image" in k.lower() or _is_image(v):
|
||||
@@ -159,168 +187,337 @@ def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any],
|
||||
return state, images
|
||||
|
||||
|
||||
def make_obs_act_transition(
|
||||
*, obs: dict[str, Any] | None = None, act: dict[str, Any] | None = None
|
||||
) -> EnvTransition:
|
||||
return {
|
||||
TransitionKey.OBSERVATION: {} if obs is None else obs,
|
||||
TransitionKey.ACTION: {} if act is None else act,
|
||||
TransitionKey.INFO: {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
TransitionKey.REWARD: None,
|
||||
TransitionKey.DONE: None,
|
||||
TransitionKey.TRUNCATED: None,
|
||||
}
|
||||
# Private Helper Functions (Common Logic)
|
||||
|
||||
|
||||
def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
|
||||
Extract complementary data from a batch dictionary.
|
||||
|
||||
This includes padding flags, task description, and indices.
|
||||
|
||||
Args:
|
||||
batch: The batch dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the extracted complementary data.
|
||||
"""
|
||||
act_dict: dict[str, Any] = {}
|
||||
for k, v in action.items():
|
||||
# Check if the value is a type that should not be converted to a tensor.
|
||||
if isinstance(v, (Rotation, dict)):
|
||||
act_dict[f"{ACTION}.{k}"] = v
|
||||
continue
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
|
||||
|
||||
return make_obs_act_transition(act=act_dict)
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
|
||||
|
||||
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
|
||||
def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransition:
|
||||
def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
|
||||
Merge two transitions, with the second one taking precedence in case of conflicts.
|
||||
|
||||
Args:
|
||||
base: The base transition.
|
||||
other: The transition to merge, which will overwrite base values.
|
||||
|
||||
Returns:
|
||||
The merged transition dictionary.
|
||||
"""
|
||||
state, images = _split_obs_to_state_and_images(observation)
|
||||
out = deepcopy(base)
|
||||
|
||||
obs_dict: dict[str, Any] = {}
|
||||
for k, v in state.items():
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
|
||||
|
||||
for cam, img in images.items():
|
||||
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
|
||||
|
||||
return make_obs_act_transition(obs=obs_dict)
|
||||
|
||||
|
||||
def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
|
||||
"""
|
||||
Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions.
|
||||
"""
|
||||
out: dict[str, Any] = {}
|
||||
action_dict = transition.get(TransitionKey.ACTION) or {}
|
||||
|
||||
if action_dict is None:
|
||||
return out
|
||||
|
||||
for k, v in action_dict.items():
|
||||
if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
|
||||
out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
|
||||
out[out_key] = float(v)
|
||||
for key in (
|
||||
TransitionKey.OBSERVATION,
|
||||
TransitionKey.ACTION,
|
||||
TransitionKey.INFO,
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
):
|
||||
if other.get(key):
|
||||
out.setdefault(key, {}).update(deepcopy(other[key]))
|
||||
|
||||
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
|
||||
if k in other:
|
||||
out[k] = other[k]
|
||||
return out
|
||||
|
||||
|
||||
def to_dataset_frame(
|
||||
transitions_or_transition: EnvTransition | Iterable[EnvTransition], features: dict[str, dict]
|
||||
) -> dict[str, any]:
|
||||
# Core Conversion Functions
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation: dict[str, Any] | None = None,
|
||||
action: dict[str, Any] | None = None,
|
||||
reward: float = 0.0,
|
||||
done: bool = False,
|
||||
truncated: bool = False,
|
||||
info: dict[str, Any] | None = None,
|
||||
complementary_data: dict[str, Any] | None = None,
|
||||
) -> EnvTransition:
|
||||
"""
|
||||
Converts a single EnvTransition or an iterable of them into a flat,
|
||||
dataset-friendly dictionary for training or evaluation, according to
|
||||
the provided `features` spec.
|
||||
Create an `EnvTransition` dictionary with sensible defaults.
|
||||
|
||||
Args:
|
||||
transitions_or_transition: Either a single EnvTransition dict
|
||||
or an iterable of them (which will be merged).
|
||||
features (dict[str, dict]):
|
||||
A feature specification dictionary:
|
||||
- 'action': dict with 'names': list of action feature names
|
||||
- 'observation.state': dict with 'names': list of state feature names
|
||||
- keys starting with 'observation.images.' are passed through
|
||||
observation: Observation dictionary.
|
||||
action: Action dictionary.
|
||||
reward: Scalar reward value.
|
||||
done: Episode termination flag.
|
||||
truncated: Episode truncation flag.
|
||||
info: Additional info dictionary.
|
||||
complementary_data: Complementary data dictionary.
|
||||
|
||||
Returns:
|
||||
batch (dict[str, any]): Flat dictionary containing:
|
||||
- numpy arrays for "observation.state" and "action"
|
||||
- any image tensors defined in features
|
||||
- next.{reward,done,truncated}
|
||||
- info dict
|
||||
- *_is_pad flags and task from complementary_data
|
||||
A complete `EnvTransition` dictionary.
|
||||
"""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info if info is not None else {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
|
||||
}
|
||||
|
||||
|
||||
def action_to_transition(action: dict[str, Any]) -> EnvTransition:
|
||||
"""
|
||||
Convert a raw action dictionary into a standardized `EnvTransition`.
|
||||
|
||||
The keys in the action dictionary are prefixed with "action." and stored under
|
||||
the `ACTION` key in the transition. Values are converted to tensors, except for
|
||||
special types like `Rotation`.
|
||||
|
||||
Args:
|
||||
action: The raw action dictionary from a teleoperation device or controller.
|
||||
|
||||
Returns:
|
||||
An `EnvTransition` containing the formatted action.
|
||||
"""
|
||||
|
||||
return create_transition(observation={}, action=action)
|
||||
|
||||
|
||||
def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
|
||||
"""
|
||||
Convert a raw robot observation dictionary into a standardized `EnvTransition`.
|
||||
|
||||
The observation is split into state and image components. State keys are prefixed
|
||||
with "observation.state." and image keys with "observation.images.". The result is
|
||||
stored under the `OBSERVATION` key in the transition.
|
||||
|
||||
Args:
|
||||
observation: The raw observation dictionary from the environment.
|
||||
|
||||
Returns:
|
||||
An `EnvTransition` containing the formatted observation.
|
||||
"""
|
||||
state, images = _split_obs_to_state_and_images(observation)
|
||||
|
||||
image_observations = {f"{OBS_IMAGES}.{cam}": img for cam, img in images.items()}
|
||||
|
||||
return create_transition(observation={**state, **image_observations}, action={})
|
||||
|
||||
|
||||
def transition_to_action(transition: EnvTransition) -> dict[str, Any]:
|
||||
"""
|
||||
Extract a raw action dictionary for a robot from an `EnvTransition`.
|
||||
|
||||
This function searches for keys in the format "action.*.pos" or "action.*.vel"
|
||||
and converts them into a flat dictionary suitable for sending to a robot controller.
|
||||
|
||||
Args:
|
||||
transition: The `EnvTransition` containing the action.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the raw robot action.
|
||||
"""
|
||||
return transition.get(TransitionKey.ACTION)
|
||||
|
||||
|
||||
def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Merge a sequence of transitions into a single one.
|
||||
|
||||
If a single transition is provided, it is returned as is. For a sequence,
|
||||
transitions are merged sequentially, with later transitions in the sequence
|
||||
overwriting earlier ones.
|
||||
|
||||
Args:
|
||||
transitions: A single transition or a sequence of them.
|
||||
|
||||
Returns:
|
||||
A single merged `EnvTransition`.
|
||||
|
||||
Raises:
|
||||
ValueError: If an empty sequence of transitions is provided.
|
||||
"""
|
||||
|
||||
if not isinstance(transitions, Sequence): # Single transition
|
||||
return transitions
|
||||
|
||||
items = list(transitions)
|
||||
if not items:
|
||||
raise ValueError("merge_transitions() requires a non-empty sequence of transitions")
|
||||
|
||||
result = items[0]
|
||||
for t in items[1:]:
|
||||
result = _merge_transitions(result, t)
|
||||
return result
|
||||
|
||||
|
||||
def transition_to_dataset_frame(
|
||||
transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert one or more transitions into a flat dictionary suitable for a dataset frame.
|
||||
|
||||
This function processes `EnvTransition` objects according to a feature
|
||||
specification, producing a format ready for training or evaluation.
|
||||
|
||||
Args:
|
||||
transitions_or_transition: A single `EnvTransition` or a sequence to be merged.
|
||||
features: A feature specification dictionary.
|
||||
|
||||
Returns:
|
||||
A flat dictionary representing a single frame of data for a dataset.
|
||||
"""
|
||||
action_names = features.get(ACTION, {}).get("names", [])
|
||||
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
|
||||
image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
|
||||
|
||||
def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
|
||||
out = deepcopy(base)
|
||||
for key in (
|
||||
TransitionKey.OBSERVATION,
|
||||
TransitionKey.ACTION,
|
||||
TransitionKey.INFO,
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
):
|
||||
if other.get(key):
|
||||
out.setdefault(key, {}).update(deepcopy(other[key]))
|
||||
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
|
||||
if k in other:
|
||||
out[k] = other[k]
|
||||
return out
|
||||
|
||||
def _ensure_transition(obj) -> EnvTransition:
|
||||
# single transition
|
||||
if isinstance(obj, dict) and any(isinstance(k, TransitionKey) for k in obj):
|
||||
return obj
|
||||
# iterable of transitions
|
||||
if isinstance(obj, Iterable):
|
||||
items = list(obj)
|
||||
if not items:
|
||||
return {}
|
||||
acc = items[0]
|
||||
for t in items[1:]:
|
||||
acc = _merge(acc, t)
|
||||
return acc
|
||||
raise TypeError("Expected EnvTransition or iterable of them")
|
||||
|
||||
tr = _ensure_transition(transitions_or_transition)
|
||||
tr = merge_transitions(transitions_or_transition)
|
||||
obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
act = tr.get(TransitionKey.ACTION, {}) or {}
|
||||
batch: dict[str, any] = {}
|
||||
batch: dict[str, Any] = {}
|
||||
|
||||
# Images passthrough
|
||||
# Passthrough for images.
|
||||
for k in image_keys:
|
||||
if k in obs:
|
||||
batch[k] = obs[k]
|
||||
|
||||
# Observation.state vector
|
||||
# Create observation.state vector.
|
||||
if obs_state_names:
|
||||
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
|
||||
vals = [from_tensor_to_numpy(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
|
||||
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Action vector
|
||||
# Create action vector.
|
||||
if action_names:
|
||||
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
|
||||
vals = [from_tensor_to_numpy(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
|
||||
batch[ACTION] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Add transition metadata.
|
||||
if tr.get(TransitionKey.REWARD) is not None:
|
||||
batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD])
|
||||
if tr.get(TransitionKey.DONE) is not None:
|
||||
batch[DONE] = _from_tensor(tr[TransitionKey.DONE])
|
||||
if tr.get(TransitionKey.TRUNCATED) is not None:
|
||||
batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED])
|
||||
reward_val = from_tensor_to_numpy(tr[TransitionKey.REWARD])
|
||||
# Check if features expect array format, otherwise keep as scalar.
|
||||
if REWARD in features and features[REWARD].get("shape") == (1,):
|
||||
batch[REWARD] = np.array([reward_val], dtype=np.float32)
|
||||
else:
|
||||
batch[REWARD] = reward_val
|
||||
|
||||
# Complementary data flags and task
|
||||
if tr.get(TransitionKey.DONE) is not None:
|
||||
done_val = from_tensor_to_numpy(tr[TransitionKey.DONE])
|
||||
if DONE in features and features[DONE].get("shape") == (1,):
|
||||
batch[DONE] = np.array([done_val], dtype=bool)
|
||||
else:
|
||||
batch[DONE] = done_val
|
||||
|
||||
if tr.get(TransitionKey.TRUNCATED) is not None:
|
||||
truncated_val = from_tensor_to_numpy(tr[TransitionKey.TRUNCATED])
|
||||
if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
|
||||
batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
|
||||
else:
|
||||
batch[TRUNCATED] = truncated_val
|
||||
|
||||
# Add complementary data flags and task.
|
||||
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if comp:
|
||||
# pad flags
|
||||
# Padding flags.
|
||||
for k, v in comp.items():
|
||||
if k.endswith("_is_pad"):
|
||||
batch[k] = v
|
||||
# task label
|
||||
# Task label.
|
||||
if comp.get("task") is not None:
|
||||
batch["task"] = comp["task"]
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
"""
|
||||
Convert a batch dictionary from a dataset/dataloader into an `EnvTransition`.
|
||||
|
||||
This function maps recognized keys from a batch to the `EnvTransition` structure,
|
||||
filling in missing keys with sensible defaults.
|
||||
|
||||
Args:
|
||||
batch: A batch dictionary.
|
||||
|
||||
Returns:
|
||||
An `EnvTransition` dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is not a dictionary.
|
||||
"""
|
||||
|
||||
# Validate input type.
|
||||
if not isinstance(batch, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||
|
||||
# Extract observation and complementary data keys.
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
complementary_data = _extract_complementary_data(batch)
|
||||
|
||||
return create_transition(
|
||||
observation=observation_keys if observation_keys else None,
|
||||
action=batch.get("action"),
|
||||
reward=batch.get("next.reward", 0.0),
|
||||
done=batch.get("next.done", False),
|
||||
truncated=batch.get("next.truncated", False),
|
||||
info=batch.get("info", {}),
|
||||
complementary_data=complementary_data if complementary_data else None,
|
||||
)
|
||||
|
||||
|
||||
def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
"""
|
||||
Convert an `EnvTransition` back to the canonical batch format used in LeRobot.
|
||||
|
||||
This is the inverse of `batch_to_transition`.
|
||||
|
||||
Args:
|
||||
transition: The `EnvTransition` to convert.
|
||||
|
||||
Returns:
|
||||
A batch dictionary with canonical LeRobot field names.
|
||||
"""
|
||||
batch = {
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Add complementary data.
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if comp_data:
|
||||
batch.update(comp_data)
|
||||
|
||||
# Flatten observation dictionary.
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if isinstance(observation, dict):
|
||||
batch.update(observation)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def identity_transition(tr: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
An identity function for transitions, returning the input unchanged.
|
||||
|
||||
Useful as a default or placeholder in processing pipelines.
|
||||
|
||||
Args:
|
||||
tr: An `EnvTransition`.
|
||||
|
||||
Returns:
|
||||
The same `EnvTransition`.
|
||||
"""
|
||||
return tr
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
|
||||
# TODO(Steven): Use consts
|
||||
OBSERVATION = "observation"
|
||||
ACTION = "action"
|
||||
REWARD = "reward"
|
||||
DONE = "done"
|
||||
TRUNCATED = "truncated"
|
||||
INFO = "info"
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
EnvTransition = TypedDict(
|
||||
"EnvTransition",
|
||||
{
|
||||
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
||||
TransitionKey.ACTION.value: Any | torch.Tensor | None,
|
||||
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||||
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||||
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
||||
TransitionKey.INFO.value: dict[str, Any] | None,
|
||||
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
|
||||
},
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
# !/usr/bin/env python
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -18,60 +18,70 @@ from dataclasses import dataclass
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||
@dataclass
|
||||
class MapTensorToDeltaActionDict(ActionProcessor):
|
||||
class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
"""
|
||||
Map a tensor to a delta action dictionary.
|
||||
Maps a flat action tensor from a policy to a structured delta action dictionary.
|
||||
|
||||
This step is typically used after a policy outputs a continuous action vector.
|
||||
It decomposes the vector into named components for delta movements of the
|
||||
end-effector (x, y, z) and optionally the gripper.
|
||||
|
||||
Attributes:
|
||||
use_gripper: If True, assumes the 4th element of the tensor is the
|
||||
gripper action.
|
||||
"""
|
||||
|
||||
use_gripper: bool = True
|
||||
|
||||
def action(self, action: Tensor) -> dict:
|
||||
if isinstance(action, dict):
|
||||
return action
|
||||
if action.dim() > 1:
|
||||
action = action.squeeze(0)
|
||||
|
||||
# TODO (maractingi): add rotation
|
||||
delta_action = {
|
||||
"action.delta_x": action[0],
|
||||
"action.delta_y": action[1],
|
||||
"action.delta_z": action[2],
|
||||
"delta_x": action[0],
|
||||
"delta_y": action[1],
|
||||
"delta_z": action[2],
|
||||
}
|
||||
if action.shape[0] > 3:
|
||||
delta_action["action.gripper"] = action[3]
|
||||
if self.use_gripper:
|
||||
delta_action["gripper"] = action[3]
|
||||
return delta_action
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION]["delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
if self.use_gripper:
|
||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
||||
@dataclass
|
||||
class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
class MapDeltaActionToRobotActionStep(ActionProcessorStep):
|
||||
"""
|
||||
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
|
||||
for use with inverse kinematics processors.
|
||||
Maps delta actions from teleoperators to robot target actions for inverse kinematics.
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.delta_x": float,
|
||||
"action.delta_y": float,
|
||||
"action.delta_z": float,
|
||||
"action.gripper": float (optional),
|
||||
}
|
||||
This step converts a dictionary of delta movements (e.g., from a gamepad)
|
||||
into a target action format that includes an "enabled" flag and target
|
||||
end-effector positions. It also handles scaling and noise filtering.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.target_x": float,
|
||||
"action.target_y": float,
|
||||
"action.target_z": float,
|
||||
"action.target_wx": float,
|
||||
"action.target_wy": float,
|
||||
"action.target_wz": float,
|
||||
"action.gripper": float,
|
||||
}
|
||||
Attributes:
|
||||
position_scale: A factor to scale the delta position inputs.
|
||||
rotation_scale: A factor to scale the delta rotation inputs (currently unused).
|
||||
noise_threshold: The magnitude below which delta inputs are considered noise
|
||||
and do not trigger an "enabled" state.
|
||||
"""
|
||||
|
||||
# Scale factors for delta movements
|
||||
@@ -82,10 +92,10 @@ class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
def action(self, action: dict) -> dict:
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||
delta_x = action.pop("action.delta_x", 0.0)
|
||||
delta_y = action.pop("action.delta_y", 0.0)
|
||||
delta_z = action.pop("action.delta_z", 0.0)
|
||||
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
|
||||
delta_x = action.pop("delta_x", 0.0)
|
||||
delta_y = action.pop("delta_y", 0.0)
|
||||
delta_z = action.pop("delta_z", 0.0)
|
||||
gripper = action.pop("gripper", 1.0) # Default to "stay" (1.0)
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
@@ -105,31 +115,33 @@ class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
|
||||
# Update action with robot target format
|
||||
action = {
|
||||
"action.enabled": enabled,
|
||||
"action.target_x": scaled_delta_x,
|
||||
"action.target_y": scaled_delta_y,
|
||||
"action.target_z": scaled_delta_z,
|
||||
"action.target_wx": target_wx,
|
||||
"action.target_wy": target_wy,
|
||||
"action.target_wz": target_wz,
|
||||
"action.gripper": float(gripper),
|
||||
"enabled": enabled,
|
||||
"target_x": scaled_delta_x,
|
||||
"target_y": scaled_delta_y,
|
||||
"target_z": scaled_delta_z,
|
||||
"target_wx": target_wx,
|
||||
"target_wy": target_wy,
|
||||
"target_wz": target_wz,
|
||||
"gripper": float(gripper),
|
||||
}
|
||||
|
||||
return action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Transform features to match output format."""
|
||||
# Update features to reflect the new action format
|
||||
features.update(
|
||||
{
|
||||
"action.enabled": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_x": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_y": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_z": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wx": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wy": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wz": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.gripper": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
}
|
||||
)
|
||||
features[PipelineFeatureType.ACTION].pop("delta_x", None)
|
||||
features[PipelineFeatureType.ACTION].pop("delta_y", None)
|
||||
features[PipelineFeatureType.ACTION].pop("delta_z", None)
|
||||
features[PipelineFeatureType.ACTION].pop("gripper", None)
|
||||
|
||||
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
@@ -13,24 +13,37 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script defines a processor step for moving environment transition data to a specific torch device and casting
|
||||
its floating-point precision.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("device_processor")
|
||||
@dataclass
|
||||
class DeviceProcessor(ProcessorStep):
|
||||
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
|
||||
class DeviceProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to move all tensors within an `EnvTransition` to a specified device and optionally cast their
|
||||
floating-point data type.
|
||||
|
||||
This processor ensures that all tensors in the transition are moved to the
|
||||
specified device (CPU or GPU) before they are returned. It can also convert
|
||||
floating-point tensors to a specified dtype while preserving non-float types
|
||||
(int, long, bool, etc.).
|
||||
This is crucial for preparing data for model training or inference on hardware like GPUs.
|
||||
|
||||
Attributes:
|
||||
device: The target device for tensors (e.g., "cpu", "cuda", "cuda:0").
|
||||
float_dtype: The target floating-point dtype as a string (e.g., "float32", "float16", "bfloat16").
|
||||
If None, the dtype is not changed.
|
||||
"""
|
||||
|
||||
device: str = "cpu"
|
||||
@@ -47,8 +60,15 @@ class DeviceProcessor(ProcessorStep):
|
||||
}
|
||||
|
||||
def __post_init__(self):
|
||||
self._device: torch.device = get_safe_torch_device(self.device)
|
||||
self.device = self._device.type # cuda might have changed to cuda:1
|
||||
"""
|
||||
Initializes the processor by converting string configurations to torch objects.
|
||||
|
||||
This method sets up the `torch.device`, determines if transfers can be non-blocking, and validates the
|
||||
`float_dtype` string, converting it to a `torch.dtype` object.
|
||||
"""
|
||||
self.tensor_device: torch.device = get_safe_torch_device(self.device)
|
||||
# Update device string in case a specific GPU was selected (e.g., "cuda" -> "cuda:0")
|
||||
self.device = self.tensor_device.type
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
# Validate and convert float_dtype string to torch dtype
|
||||
@@ -57,28 +77,33 @@ class DeviceProcessor(ProcessorStep):
|
||||
raise ValueError(
|
||||
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
|
||||
)
|
||||
|
||||
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
|
||||
else:
|
||||
self._target_float_dtype = None
|
||||
|
||||
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Process a tensor by moving to device and optionally converting float dtype.
|
||||
"""
|
||||
Moves a single tensor to the target device and casts its dtype.
|
||||
|
||||
If the tensor is already on a GPU and we're configured for a GPU, it preserves
|
||||
that GPU placement (useful for multi-GPU training with Accelerate).
|
||||
Otherwise, it moves to the configured device.
|
||||
Handles multi-GPU scenarios by not moving a tensor if it's already on a different CUDA device than
|
||||
the target, which is useful when using frameworks like Accelerate.
|
||||
|
||||
Args:
|
||||
tensor: The input torch.Tensor.
|
||||
|
||||
Returns:
|
||||
The processed tensor on the correct device and with the correct dtype.
|
||||
"""
|
||||
# Determine target device
|
||||
if tensor.is_cuda and self._device.type == "cuda":
|
||||
# Both tensor and target are on GPU - preserve tensor's GPU placement
|
||||
if tensor.is_cuda and self.tensor_device.type == "cuda":
|
||||
# Both tensor and target are on GPU - preserve tensor's GPU placement.
|
||||
# This handles multi-GPU scenarios where Accelerate has already placed
|
||||
# tensors on the correct GPU for each process
|
||||
# tensors on the correct GPU for each process.
|
||||
target_device = tensor.device
|
||||
else:
|
||||
# Either tensor is on CPU, or we're configured for CPU
|
||||
# In both cases, use the configured device
|
||||
target_device = self._device
|
||||
# Either tensor is on CPU, or we're configured for CPU.
|
||||
# In both cases, use the configured device.
|
||||
target_device = self.tensor_device
|
||||
|
||||
# Only move if necessary
|
||||
if tensor.device != target_device:
|
||||
@@ -91,6 +116,18 @@ class DeviceProcessor(ProcessorStep):
|
||||
return tensor
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Applies device and dtype conversion to all tensors in an environment transition.
|
||||
|
||||
It iterates through the transition, finds all `torch.Tensor` objects (including those nested in
|
||||
dictionaries like `observation`), and processes them.
|
||||
|
||||
Args:
|
||||
transition: The input `EnvTransition` object.
|
||||
|
||||
Returns:
|
||||
A new `EnvTransition` object with all tensors moved to the target device and dtype.
|
||||
"""
|
||||
new_transition = transition.copy()
|
||||
|
||||
simple_tensor_keys = [
|
||||
@@ -105,13 +142,13 @@ class DeviceProcessor(ProcessorStep):
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
]
|
||||
|
||||
# Process simple tensors
|
||||
# Process simple, top-level tensors
|
||||
for key in simple_tensor_keys:
|
||||
value = transition.get(key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
new_transition[key] = self._process_tensor(value)
|
||||
|
||||
# Process dictionary-like tensors
|
||||
# Process tensors nested within dictionaries
|
||||
for key in dict_tensor_keys:
|
||||
data_dict = transition.get(key)
|
||||
if data_dict is not None:
|
||||
@@ -124,5 +161,26 @@ class DeviceProcessor(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
"""
|
||||
Returns the serializable configuration of the processor.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the device and float_dtype settings.
|
||||
"""
|
||||
return {"device": self.device, "float_dtype": self.float_dtype}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Device and dtype transformations do not alter the fundamental definition of the features (e.g., shape).
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#! /usr/bin/env python
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -10,20 +10,35 @@
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .converters import to_tensor
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("torch2numpy_action_processor")
|
||||
@dataclass
|
||||
class Torch2NumpyActionProcessor(ActionProcessor):
|
||||
"""Convert PyTorch tensor actions to NumPy arrays."""
|
||||
class Torch2NumpyActionProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
Converts a PyTorch tensor action to a NumPy array.
|
||||
|
||||
This step is useful when the output of a policy (typically a torch.Tensor)
|
||||
needs to be passed to an environment or component that expects a NumPy array.
|
||||
|
||||
Attributes:
|
||||
squeeze_batch_dim: If True, removes the first dimension of the array
|
||||
if it is of size 1. This is useful for converting a
|
||||
batched action of size (1, D) to a single action of size (D,).
|
||||
"""
|
||||
|
||||
squeeze_batch_dim: bool = True
|
||||
|
||||
@@ -36,8 +51,8 @@ class Torch2NumpyActionProcessor(ActionProcessor):
|
||||
|
||||
numpy_action = action.detach().cpu().numpy()
|
||||
|
||||
# Remove batch dimensions but preserve action dimensions
|
||||
# Only squeeze if there's a batch dimension (first dim == 1)
|
||||
# Remove batch dimensions but preserve action dimensions.
|
||||
# Only squeeze if there's a batch dimension (first dim == 1).
|
||||
if (
|
||||
self.squeeze_batch_dim
|
||||
and numpy_action.shape
|
||||
@@ -48,11 +63,22 @@ class Torch2NumpyActionProcessor(ActionProcessor):
|
||||
|
||||
return numpy_action
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
||||
@dataclass
|
||||
class Numpy2TorchActionProcessor(ActionProcessor):
|
||||
"""Convert NumPy array action to PyTorch tensor."""
|
||||
class Numpy2TorchActionProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
Converts a NumPy array action to a PyTorch tensor.
|
||||
|
||||
This step is useful for converting actions from environments or hardware,
|
||||
which are often NumPy arrays, into PyTorch tensors that can be processed
|
||||
by a policy or model.
|
||||
"""
|
||||
|
||||
def action(self, action: np.ndarray) -> torch.Tensor:
|
||||
if not isinstance(action, np.ndarray):
|
||||
@@ -62,3 +88,8 @@ class Numpy2TorchActionProcessor(ActionProcessor):
|
||||
)
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
return torch_action
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -1,68 +1,203 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.processor.pipeline import (
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
TruncatedProcessor,
|
||||
)
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
InfoProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TruncatedProcessorStep,
|
||||
)
|
||||
|
||||
GRIPPER_KEY = "gripper"
|
||||
DISCRETE_PENALTY_KEY = "discrete_penalty"
|
||||
TELEOP_ACTION_KEY = "teleop_action"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasTeleopEvents(Protocol):
|
||||
"""
|
||||
Minimal protocol for objects that provide teleoperation events.
|
||||
|
||||
This protocol defines the `get_teleop_events()` method, allowing processor
|
||||
steps to interact with teleoperators that support event-based controls
|
||||
(like episode termination or success flagging) without needing to know the
|
||||
teleoperator's specific class.
|
||||
"""
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get extra control events from the teleoperator.
|
||||
|
||||
Returns:
|
||||
A dictionary containing control events such as:
|
||||
- `is_intervention`: bool - Whether the human is currently intervening.
|
||||
- `terminate_episode`: bool - Whether to terminate the current episode.
|
||||
- `success`: bool - Whether the episode was successful.
|
||||
- `rerecord_episode`: bool - Whether to rerecord the episode.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Type variable constrained to Teleoperator subclasses that also implement events
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
"""
|
||||
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
|
||||
|
||||
Args:
|
||||
teleop: The teleoperator instance to check.
|
||||
|
||||
Raises:
|
||||
TypeError: If the teleoperator does not have a `get_teleop_events` method.
|
||||
"""
|
||||
if not isinstance(teleop, HasTeleopEvents):
|
||||
raise TypeError(
|
||||
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
|
||||
f"Compatible teleoperators: GamepadTeleop, KeyboardEndEffectorTeleop"
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||
@dataclass
|
||||
class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
|
||||
"""Add teleoperator action to transition complementary data."""
|
||||
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Adds the raw action from a teleoperator to the transition's complementary data.
|
||||
|
||||
This is useful for human-in-the-loop scenarios where the human's input needs to
|
||||
be available to downstream processors, for example, to override a policy's action
|
||||
during an intervention.
|
||||
|
||||
Attributes:
|
||||
teleop_device: The teleoperator instance to get the action from.
|
||||
"""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Retrieves the teleoperator's action and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data dictionary.
|
||||
|
||||
Returns:
|
||||
A new dictionary with the teleoperator action added under the
|
||||
`teleop_action` key.
|
||||
"""
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||
@dataclass
|
||||
class AddTeleopEventsAsInfo(InfoProcessor):
|
||||
"""Add teleoperator control events to transition info."""
|
||||
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
|
||||
"""
|
||||
Adds teleoperator control events (e.g., terminate, success) to the transition's info.
|
||||
|
||||
teleop_device: Teleoperator
|
||||
This step extracts control events from teleoperators that support event-based
|
||||
interaction, making these signals available to other parts of the system.
|
||||
|
||||
Attributes:
|
||||
teleop_device: An instance of a teleoperator that implements the
|
||||
`HasTeleopEvents` protocol.
|
||||
"""
|
||||
|
||||
teleop_device: TeleopWithEvents
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validates that the provided teleoperator supports events after initialization."""
|
||||
_check_teleop_with_events(self.teleop_device)
|
||||
|
||||
def info(self, info: dict) -> dict:
|
||||
"""
|
||||
Retrieves teleoperator events and updates the info dictionary.
|
||||
|
||||
Args:
|
||||
info: The incoming info dictionary.
|
||||
|
||||
Returns:
|
||||
A new dictionary including the teleoperator events.
|
||||
"""
|
||||
new_info = dict(info)
|
||||
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
|
||||
|
||||
teleop_events = self.teleop_device.get_teleop_events()
|
||||
new_info.update(teleop_events)
|
||||
return new_info
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||
@dataclass
|
||||
class ImageCropResizeProcessor(ObservationProcessor):
|
||||
"""Crop and resize image observations."""
|
||||
class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Crops and/or resizes image observations.
|
||||
|
||||
This step iterates through all image keys in an observation dictionary and applies
|
||||
the specified transformations. It handles device placement, moving tensors to the
|
||||
CPU if necessary for operations not supported on certain accelerators like MPS.
|
||||
|
||||
Attributes:
|
||||
crop_params_dict: A dictionary mapping image keys to cropping parameters
|
||||
(top, left, height, width).
|
||||
resize_size: A tuple (height, width) to resize all images to.
|
||||
"""
|
||||
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
"""
|
||||
Applies cropping and resizing to all images in the observation dictionary.
|
||||
|
||||
Args:
|
||||
observation: The observation dictionary, potentially containing image tensors.
|
||||
|
||||
Returns:
|
||||
A new observation dictionary with transformed images.
|
||||
"""
|
||||
if self.resize_size is None and not self.crop_params_dict:
|
||||
return observation
|
||||
|
||||
@@ -90,29 +225,65 @@ class ImageCropResizeProcessor(ObservationProcessor):
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary with the crop parameters and resize dimensions.
|
||||
"""
|
||||
return {
|
||||
"crop_params_dict": self.crop_params_dict,
|
||||
"resize_size": self.resize_size,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Updates the image feature shapes in the policy features dictionary if resizing is applied.
|
||||
|
||||
Args:
|
||||
features: The policy features dictionary.
|
||||
|
||||
Returns:
|
||||
The updated policy features dictionary with new image shapes.
|
||||
"""
|
||||
if self.resize_size is None:
|
||||
return features
|
||||
for key in features:
|
||||
for key in features[PipelineFeatureType.OBSERVATION]:
|
||||
if "image" in key:
|
||||
features[key] = PolicyFeature(type=features[key].type, shape=self.resize_size)
|
||||
nb_channel = features[PipelineFeatureType.OBSERVATION][key].shape[0]
|
||||
features[PipelineFeatureType.OBSERVATION][key] = PolicyFeature(
|
||||
type=features[PipelineFeatureType.OBSERVATION][key].type,
|
||||
shape=(nb_channel, *self.resize_size),
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("time_limit_processor")
|
||||
class TimeLimitProcessor(TruncatedProcessor):
|
||||
"""Track episode steps and enforce time limits."""
|
||||
class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
"""
|
||||
Tracks episode steps and enforces a time limit by truncating the episode.
|
||||
|
||||
Attributes:
|
||||
max_episode_steps: The maximum number of steps allowed per episode.
|
||||
current_step: The current step count for the active episode.
|
||||
"""
|
||||
|
||||
max_episode_steps: int
|
||||
current_step: int = 0
|
||||
|
||||
def truncated(self, truncated):
|
||||
def truncated(self, truncated: bool) -> bool:
|
||||
"""
|
||||
Increments the step counter and sets the truncated flag if the time limit is reached.
|
||||
|
||||
Args:
|
||||
truncated: The incoming truncated flag.
|
||||
|
||||
Returns:
|
||||
True if the episode step limit is reached, otherwise the incoming value.
|
||||
"""
|
||||
self.current_step += 1
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
truncated = True
|
||||
@@ -120,24 +291,54 @@ class TimeLimitProcessor(TruncatedProcessor):
|
||||
return truncated
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the `max_episode_steps`.
|
||||
"""
|
||||
return {
|
||||
"max_episode_steps": self.max_episode_steps,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the step counter, typically called at the start of a new episode."""
|
||||
self.current_step = 0
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
||||
"""Apply penalty for inappropriate gripper usage."""
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
This step penalizes actions that attempt to close an already closed gripper or
|
||||
open an already open one, based on position thresholds.
|
||||
|
||||
Attributes:
|
||||
penalty: The negative reward value to apply.
|
||||
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
||||
"""
|
||||
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
"""Calculate gripper penalty and add to complementary data."""
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
|
||||
Returns:
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
"""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
|
||||
@@ -164,25 +365,57 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
||||
return new_complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the penalty value and max gripper position.
|
||||
"""
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the processor state."""
|
||||
self.last_gripper_state = None
|
||||
"""Resets the processor's internal state."""
|
||||
pass
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||
class InterventionActionProcessor(ProcessorStep):
|
||||
"""Handle human intervention actions and episode termination."""
|
||||
class InterventionActionProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Handles human intervention, overriding policy actions and managing episode termination.
|
||||
|
||||
When an intervention is detected (via teleoperator events in the `info` dict),
|
||||
this step replaces the policy's action with the human's teleoperated action.
|
||||
It also processes signals to terminate the episode or flag success.
|
||||
|
||||
Attributes:
|
||||
use_gripper: Whether to include the gripper in the teleoperated action.
|
||||
terminate_on_success: If True, automatically sets the `done` flag when a
|
||||
`success` event is received.
|
||||
"""
|
||||
|
||||
use_gripper: bool = False
|
||||
terminate_on_success: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Processes the transition to handle interventions.
|
||||
|
||||
Args:
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
The modified transition, potentially with an overridden action, updated
|
||||
reward, and termination status.
|
||||
"""
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return transition
|
||||
@@ -238,16 +471,40 @@ class InterventionActionProcessor(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the step's configuration attributes.
|
||||
"""
|
||||
return {
|
||||
"use_gripper": self.use_gripper,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("reward_classifier_processor")
|
||||
class RewardClassifierProcessor(ProcessorStep):
|
||||
"""Apply reward classification to image observations."""
|
||||
class RewardClassifierProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Applies a pretrained reward classifier to image observations to predict success.
|
||||
|
||||
This step uses a model to determine if the current state is successful, updating
|
||||
the reward and potentially terminating the episode.
|
||||
|
||||
Attributes:
|
||||
pretrained_path: Path to the pretrained reward classifier model.
|
||||
device: The device to run the classifier on.
|
||||
success_threshold: The probability threshold to consider a prediction as successful.
|
||||
success_reward: The reward value to assign on success.
|
||||
terminate_on_success: If True, terminates the episode upon successful classification.
|
||||
reward_classifier: The loaded classifier model instance.
|
||||
"""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
device: str = "cpu"
|
||||
@@ -258,7 +515,7 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
reward_classifier: Any = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the reward classifier after dataclass initialization."""
|
||||
"""Initializes the reward classifier model after the dataclass is created."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
@@ -267,15 +524,26 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
self.reward_classifier.eval()
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
"""
|
||||
Processes a transition, applying the reward classifier to its image observations.
|
||||
|
||||
Args:
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
The modified transition with an updated reward and done flag based on the
|
||||
classifier's prediction.
|
||||
"""
|
||||
new_transition = transition.copy()
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None or self.reward_classifier is None:
|
||||
return transition
|
||||
return new_transition
|
||||
|
||||
# Extract images from observation
|
||||
images = {key: value for key, value in observation.items() if "image" in key}
|
||||
|
||||
if not images:
|
||||
return transition
|
||||
return new_transition
|
||||
|
||||
# Run reward classifier
|
||||
start_time = time.perf_counter()
|
||||
@@ -285,8 +553,8 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
classifier_frequency = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
# Calculate reward and termination
|
||||
reward = transition.get(TransitionKey.REWARD, 0.0)
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
reward = new_transition.get(TransitionKey.REWARD, 0.0)
|
||||
terminated = new_transition.get(TransitionKey.DONE, False)
|
||||
|
||||
if math.isclose(success, 1, abs_tol=1e-2):
|
||||
reward = self.success_reward
|
||||
@@ -294,7 +562,6 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
terminated = True
|
||||
|
||||
# Update transition
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = reward
|
||||
new_transition[TransitionKey.DONE] = terminated
|
||||
|
||||
@@ -306,9 +573,20 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the step's configuration attributes.
|
||||
"""
|
||||
return {
|
||||
"device": self.device,
|
||||
"success_threshold": self.success_threshold,
|
||||
"success_reward": self.success_reward,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -1,11 +1,28 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_STATE
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessor,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
@@ -13,19 +30,44 @@ from lerobot.robots import Robot
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||
class JointVelocityProcessor(ObservationProcessor):
|
||||
"""Add joint velocity information to observations."""
|
||||
class JointVelocityProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Calculates and appends joint velocity information to the observation state.
|
||||
|
||||
This step computes the velocity of each joint by calculating the finite
|
||||
difference between the current and the last observed joint positions. The
|
||||
resulting velocity vector is then concatenated to the original state vector.
|
||||
|
||||
Attributes:
|
||||
dt: The time step (delta time) in seconds between observations, used for
|
||||
calculating velocity.
|
||||
last_joint_positions: Stores the joint positions from the previous step
|
||||
to enable velocity calculation.
|
||||
"""
|
||||
|
||||
dt: float = 0.1
|
||||
|
||||
last_joint_positions: torch.Tensor | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
"""
|
||||
Computes joint velocities and adds them to the observation state.
|
||||
|
||||
Args:
|
||||
observation: The input observation dictionary, expected to contain
|
||||
an `observation.state` key with joint positions.
|
||||
|
||||
Returns:
|
||||
A new observation dictionary with the `observation.state` tensor
|
||||
extended to include joint velocities.
|
||||
|
||||
Raises:
|
||||
ValueError: If `observation.state` is not found in the observation.
|
||||
"""
|
||||
# Get current joint positions (assuming they're in observation.state)
|
||||
current_positions = observation.get("observation.state")
|
||||
current_positions = observation.get(OBS_STATE)
|
||||
if current_positions is None:
|
||||
# TODO(steven): if we get here, then the transform_features method will not hold
|
||||
return observation
|
||||
raise ValueError(f"{OBS_STATE} is not in observation")
|
||||
|
||||
# Initialize last joint positions if not already set
|
||||
if self.last_joint_positions is None:
|
||||
@@ -42,46 +84,92 @@ class JointVelocityProcessor(ObservationProcessor):
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation["observation.state"] = extended_state
|
||||
new_observation[OBS_STATE] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the time step `dt`.
|
||||
"""
|
||||
return {
|
||||
"dt": self.dt,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the internal state, clearing the last known joint positions."""
|
||||
self.last_joint_positions = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if "observation.state" in features:
|
||||
original_feature = features["observation.state"]
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Updates the `observation.state` feature to reflect the added velocities.
|
||||
|
||||
This method doubles the size of the first dimension of the `observation.state`
|
||||
shape to account for the concatenation of position and velocity vectors.
|
||||
|
||||
Args:
|
||||
features: The policy features dictionary.
|
||||
|
||||
Returns:
|
||||
The updated policy features dictionary.
|
||||
"""
|
||||
if OBS_STATE in features[PipelineFeatureType.OBSERVATION]:
|
||||
original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE]
|
||||
# Double the shape to account for positions + velocities
|
||||
new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]
|
||||
|
||||
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature(
|
||||
type=original_feature.type, shape=new_shape
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("current_processor")
|
||||
class MotorCurrentProcessor(ObservationProcessor):
|
||||
"""Add motor current information to observations."""
|
||||
class MotorCurrentProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Reads motor currents from a robot and appends them to the observation state.
|
||||
|
||||
This step queries the robot's hardware interface to get the present current
|
||||
for each motor and concatenates this information to the existing state vector.
|
||||
|
||||
Attributes:
|
||||
robot: An instance of a `lerobot` Robot class that provides access to
|
||||
the hardware bus.
|
||||
"""
|
||||
|
||||
robot: Robot | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
"""
|
||||
Fetches motor currents and adds them to the observation state.
|
||||
|
||||
Args:
|
||||
observation: The input observation dictionary.
|
||||
|
||||
Returns:
|
||||
A new observation dictionary with the `observation.state` tensor
|
||||
extended to include motor currents.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `robot` attribute has not been set.
|
||||
"""
|
||||
# Get current values from robot state
|
||||
if self.robot is None:
|
||||
return observation
|
||||
raise ValueError("Robot is not set")
|
||||
|
||||
present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
|
||||
motor_currents = torch.tensor(
|
||||
[present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
|
||||
dtype=torch.float32,
|
||||
).unsqueeze(0)
|
||||
|
||||
current_state = observation.get("observation.state")
|
||||
current_state = observation.get(OBS_STATE)
|
||||
if current_state is None:
|
||||
return observation
|
||||
|
||||
@@ -89,15 +177,27 @@ class MotorCurrentProcessor(ObservationProcessor):
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation["observation.state"] = extended_state
|
||||
new_observation[OBS_STATE] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if "observation.state" in features and self.robot is not None:
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Updates the `observation.state` feature to reflect the added motor currents.
|
||||
|
||||
original_feature = features["observation.state"]
|
||||
This method increases the size of the first dimension of the `observation.state`
|
||||
shape by the number of motors in the robot.
|
||||
|
||||
Args:
|
||||
features: The policy features dictionary.
|
||||
|
||||
Returns:
|
||||
The updated policy features dictionary.
|
||||
"""
|
||||
if OBS_STATE in features[PipelineFeatureType.OBSERVATION] and self.robot is not None:
|
||||
original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE]
|
||||
# Add motor current dimensions to the original state shape
|
||||
num_motors = 0
|
||||
if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
|
||||
@@ -105,5 +205,7 @@ class MotorCurrentProcessor(ObservationProcessor):
|
||||
|
||||
if num_motors > 0:
|
||||
new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
|
||||
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature(
|
||||
type=original_feature.type, shape=new_shape
|
||||
)
|
||||
return features
|
||||
|
||||
@@ -15,16 +15,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
|
||||
A generic script to migrate LeRobot policies with built-in normalization layers to the new
|
||||
pipeline-based processor system.
|
||||
|
||||
This script:
|
||||
1. Loads an existing pretrained policy model
|
||||
2. Extracts normalization statistics from the model
|
||||
3. Creates both preprocessor and postprocessor:
|
||||
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
|
||||
- Postprocessor: unnormalizes outputs (actions) for inference
|
||||
4. Removes normalization layers from the model state_dict
|
||||
5. Saves the new model and both processors
|
||||
This script performs the following steps:
|
||||
1. Loads a pretrained policy model and its configuration from a local path or the
|
||||
Hugging Face Hub.
|
||||
2. Scans the model's state dictionary to extract normalization statistics (e.g., mean,
|
||||
std, min, max) for all features.
|
||||
3. Creates two new processor pipelines:
|
||||
- A preprocessor that normalizes inputs (observations) and outputs (actions).
|
||||
- A postprocessor that unnormalizes outputs (actions) for inference.
|
||||
4. Removes the original normalization layers from the model's state dictionary,
|
||||
creating a "clean" model.
|
||||
5. Saves the new clean model, the preprocessor, the postprocessor, and a generated
|
||||
model card to a new directory.
|
||||
6. Optionally pushes all the new artifacts to the Hugging Face Hub.
|
||||
|
||||
Usage:
|
||||
python src/lerobot/processor/migrate_policy_normalization.py \
|
||||
@@ -46,11 +52,12 @@ from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.batch_processor import ToBatchProcessor
|
||||
from lerobot.processor.device_processor import DeviceProcessor
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.processor.rename_processor import RenameProcessor
|
||||
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from .pipeline import PolicyProcessorPipeline
|
||||
from .rename_processor import RenameObservationsProcessorStep
|
||||
|
||||
# Policy type to class mapping
|
||||
POLICY_CLASSES = {
|
||||
@@ -67,7 +74,21 @@ POLICY_CLASSES = {
|
||||
|
||||
|
||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Extract normalization statistics from model state_dict."""
|
||||
"""
|
||||
Scans a model's state_dict to find and extract normalization statistics.
|
||||
|
||||
This function identifies keys corresponding to normalization layers (e.g., those
|
||||
for mean, std, min, max) based on a set of predefined patterns and organizes
|
||||
them into a nested dictionary.
|
||||
|
||||
Args:
|
||||
state_dict: The state dictionary of a pretrained policy model.
|
||||
|
||||
Returns:
|
||||
A nested dictionary where outer keys are feature names (e.g.,
|
||||
'observation.state') and inner keys are statistic types ('mean', 'std'),
|
||||
mapping to their corresponding tensor values.
|
||||
"""
|
||||
stats = {}
|
||||
|
||||
# Define patterns to match and their prefixes to remove
|
||||
@@ -111,7 +132,25 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
|
||||
def detect_features_and_norm_modes(
|
||||
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
|
||||
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
|
||||
"""Detect features and normalization modes from config and stats."""
|
||||
"""
|
||||
Infers policy features and normalization modes from the model config and stats.
|
||||
|
||||
This function first attempts to find feature definitions and normalization
|
||||
mappings directly from the policy's configuration file. If this information is
|
||||
not present, it infers it from the extracted normalization statistics, using
|
||||
tensor shapes to determine feature shapes and the presence of specific stat
|
||||
keys (e.g., 'mean'/'std' vs 'min'/'max') to determine the normalization mode.
|
||||
It applies sensible defaults if inference is not possible.
|
||||
|
||||
Args:
|
||||
config: The policy's configuration dictionary from `config.json`.
|
||||
stats: The normalization statistics extracted from the model's state_dict.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A dictionary mapping feature names to `PolicyFeature` objects.
|
||||
- A dictionary mapping `FeatureType` enums to `NormalizationMode` enums.
|
||||
"""
|
||||
features = {}
|
||||
norm_modes = {}
|
||||
|
||||
@@ -203,7 +242,19 @@ def detect_features_and_norm_modes(
|
||||
|
||||
|
||||
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Remove normalization layers from state_dict."""
|
||||
"""
|
||||
Creates a new state_dict with all normalization-related layers removed.
|
||||
|
||||
This function filters the original state dictionary, excluding any keys that
|
||||
match a set of predefined patterns associated with normalization modules.
|
||||
|
||||
Args:
|
||||
state_dict: The original model state dictionary.
|
||||
|
||||
Returns:
|
||||
A new state dictionary containing only the core model weights, without
|
||||
any normalization parameters.
|
||||
"""
|
||||
new_state_dict = {}
|
||||
|
||||
# Patterns to remove
|
||||
@@ -227,7 +278,16 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
|
||||
|
||||
|
||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""Convert features from old format to PolicyFeature objects."""
|
||||
"""
|
||||
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
|
||||
|
||||
Args:
|
||||
features_dict: The feature dictionary in the old format, where values are
|
||||
simple dictionaries (e.g., `{"shape": [7]}`).
|
||||
|
||||
Returns:
|
||||
A dictionary mapping feature names to `PolicyFeature` dataclass objects.
|
||||
"""
|
||||
converted_features = {}
|
||||
|
||||
for key, feature_dict in features_dict.items():
|
||||
@@ -253,8 +313,18 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
||||
def load_model_from_hub(
|
||||
repo_id: str, revision: str = None
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
||||
"""Load model state_dict and config from hub."""
|
||||
# Download files
|
||||
"""
|
||||
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
repo_id: The repository ID on the Hub (e.g., 'lerobot/aloha').
|
||||
revision: The specific git revision (branch, tag, or commit hash) to use.
|
||||
|
||||
Returns:
|
||||
A tuple containing the model's state dictionary, the policy configuration,
|
||||
and the training configuration.
|
||||
"""
|
||||
# Download files.
|
||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||
|
||||
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
||||
@@ -403,8 +473,8 @@ def main():
|
||||
# Now create preprocessor and postprocessor with cleaned_config available
|
||||
print("Creating preprocessor and postprocessor...")
|
||||
# The pattern from existing processor factories:
|
||||
# - Preprocessor has two NormalizerProcessors: one for input_features, one for output_features
|
||||
# - Postprocessor has one UnnormalizerProcessor for output_features only
|
||||
# - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
|
||||
# - Postprocessor has one UnnormalizerProcessorStep for output_features only
|
||||
|
||||
# Get features from cleaned_config (now they're PolicyFeature objects)
|
||||
input_features = cleaned_config.get("input_features", {})
|
||||
@@ -412,23 +482,23 @@ def main():
|
||||
|
||||
# Create preprocessor with two normalizers (following the pattern from processor factories)
|
||||
preprocessor_steps = [
|
||||
RenameProcessor(rename_map={}),
|
||||
NormalizerProcessor(
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**input_features, **output_features},
|
||||
norm_map=norm_map,
|
||||
stats=stats,
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=policy_config.device),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=policy_config.device),
|
||||
]
|
||||
preprocessor = RobotProcessor(steps=preprocessor_steps, name="robot_preprocessor")
|
||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")
|
||||
|
||||
# Create postprocessor with unnormalizer for outputs only
|
||||
postprocessor_steps = [
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(features=output_features, norm_map=norm_map, stats=stats),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
|
||||
]
|
||||
postprocessor = RobotProcessor(steps=postprocessor_steps, name="robot_postprocessor")
|
||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")
|
||||
|
||||
# Determine hub repo ID if pushing to hub
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
@@ -7,16 +24,12 @@ from typing import Any
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.pipeline import (
|
||||
EnvTransition,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RobotProcessor,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -24,22 +37,48 @@ class _NormalizationMixin:
|
||||
"""
|
||||
A mixin class providing core functionality for normalization and unnormalization.
|
||||
|
||||
This class manages normalization statistics, their conversion to tensors, device placement,
|
||||
and the application of normalization transformations. It is designed to be inherited by
|
||||
concrete ProcessorStep implementations.
|
||||
This class manages normalization statistics (`stats`), converts them to tensors for
|
||||
efficient computation, handles device placement, and implements the logic for
|
||||
applying normalization transformations (mean/std and min/max). It is designed to
|
||||
be inherited by concrete `ProcessorStep` implementations and should not be used
|
||||
directly.
|
||||
|
||||
Attributes:
|
||||
features: A dictionary mapping feature names to `PolicyFeature` objects, defining
|
||||
the data structure to be processed.
|
||||
norm_map: A dictionary mapping `FeatureType` to `NormalizationMode`, specifying
|
||||
which normalization method to use for each type of feature.
|
||||
stats: A dictionary containing the normalization statistics (e.g., mean, std,
|
||||
min, max) for each feature.
|
||||
device: The PyTorch device on which to store and perform tensor operations.
|
||||
eps: A small epsilon value to prevent division by zero in normalization
|
||||
calculations.
|
||||
normalize_observation_keys: An optional set of keys to selectively apply
|
||||
normalization to specific observation features.
|
||||
_tensor_stats: An internal dictionary holding the normalization statistics as
|
||||
PyTorch tensors.
|
||||
"""
|
||||
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
device: torch.device | str | None = None
|
||||
dtype: torch.dtype | None = None
|
||||
eps: float = 1e-8
|
||||
normalize_observation_keys: set[str] | None = None
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
# Robust JSON deserialization handling (guard empty maps)
|
||||
"""
|
||||
Initializes the mixin after dataclass construction.
|
||||
|
||||
This method handles the robust deserialization of `features` and `norm_map`
|
||||
from JSON-compatible formats (where enums become strings and tuples become
|
||||
lists) and converts the provided `stats` dictionary into a dictionary of
|
||||
tensors (`_tensor_stats`) on the specified device.
|
||||
"""
|
||||
# Robust JSON deserialization handling (guard empty maps).
|
||||
if self.features:
|
||||
first_val = next(iter(self.features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
@@ -60,15 +99,40 @@ class _NormalizationMixin:
|
||||
|
||||
# Convert stats to tensors and move to the target device once during initialization.
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
|
||||
def to(self, device: torch.device | str) -> _NormalizationMixin:
|
||||
"""Moves the processor's normalization stats to the specified device and returns self."""
|
||||
self.device = device
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
) -> _NormalizationMixin:
|
||||
"""
|
||||
Moves the processor's normalization stats to the specified device.
|
||||
|
||||
Args:
|
||||
device: The target PyTorch device.
|
||||
|
||||
Returns:
|
||||
The instance of the class, allowing for method chaining.
|
||||
"""
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
"""
|
||||
Returns the normalization statistics as a flat state dictionary.
|
||||
|
||||
All tensors are moved to the CPU before being returned, which is standard practice
|
||||
for saving state dictionaries.
|
||||
|
||||
Returns:
|
||||
A flat dictionary mapping from `'feature_name.stat_name'` to the
|
||||
corresponding statistics tensor on the CPU.
|
||||
"""
|
||||
flat: dict[str, Tensor] = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
@@ -76,6 +140,15 @@ class _NormalizationMixin:
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: dict[str, Tensor]) -> None:
|
||||
"""
|
||||
Loads normalization statistics from a state dictionary.
|
||||
|
||||
The loaded tensors are moved to the processor's configured device.
|
||||
|
||||
Args:
|
||||
state: A flat state dictionary with keys in the format
|
||||
`'feature_name.stat_name'`.
|
||||
"""
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
@@ -84,7 +157,26 @@ class _NormalizationMixin:
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
|
||||
self.stats = {}
|
||||
for key, tensor_dict in self._tensor_stats.items():
|
||||
self.stats[key] = {}
|
||||
for stat_name, tensor in tensor_dict.items():
|
||||
# Convert tensor back to python/numpy format
|
||||
self.stats[key][stat_name] = from_tensor_to_numpy(tensor)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns a serializable dictionary of the processor's configuration.
|
||||
|
||||
This method is used when saving the processor to disk, ensuring that its
|
||||
configuration can be reconstructed later.
|
||||
|
||||
Returns:
|
||||
A JSON-serializable dictionary containing the configuration.
|
||||
"""
|
||||
config = {
|
||||
"eps": self.eps,
|
||||
"features": {
|
||||
@@ -97,24 +189,63 @@ class _NormalizationMixin:
|
||||
return config
|
||||
|
||||
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
|
||||
"""
|
||||
Applies (un)normalization to all relevant features in an observation dictionary.
|
||||
|
||||
Args:
|
||||
observation: The observation dictionary to process.
|
||||
inverse: If `True`, applies unnormalization; otherwise, applies normalization.
|
||||
|
||||
Returns:
|
||||
A new observation dictionary with the transformed tensor values.
|
||||
"""
|
||||
new_observation = dict(observation)
|
||||
for key, feature in self.features.items():
|
||||
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
|
||||
continue
|
||||
if feature.type != FeatureType.ACTION and key in new_observation:
|
||||
tensor = torch.as_tensor(new_observation[key], dtype=torch.float32)
|
||||
# Convert to tensor but preserve original dtype for adaptation logic
|
||||
tensor = torch.as_tensor(new_observation[key])
|
||||
new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
|
||||
return new_observation
|
||||
|
||||
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
|
||||
tensor = torch.as_tensor(action, dtype=torch.float32)
|
||||
# Convert to tensor but preserve original dtype for adaptation logic
|
||||
"""
|
||||
Applies (un)normalization to an action tensor.
|
||||
|
||||
Args:
|
||||
action: The action tensor to process.
|
||||
inverse: If `True`, applies unnormalization; otherwise, applies normalization.
|
||||
|
||||
Returns:
|
||||
The transformed action tensor.
|
||||
"""
|
||||
tensor = torch.as_tensor(action)
|
||||
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
|
||||
return processed_action
|
||||
|
||||
def _apply_transform(
|
||||
self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
|
||||
) -> Tensor:
|
||||
"""Core logic to apply normalization or unnormalization."""
|
||||
"""
|
||||
Core logic to apply a normalization or unnormalization transformation to a tensor.
|
||||
|
||||
This method selects the appropriate normalization mode (e.g., mean/std, min/max)
|
||||
based on the feature type and applies the corresponding mathematical operation.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor to transform.
|
||||
key: The feature key corresponding to the tensor.
|
||||
feature_type: The `FeatureType` of the tensor.
|
||||
inverse: If `True`, applies the inverse transformation (unnormalization).
|
||||
|
||||
Returns:
|
||||
The transformed tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported normalization mode is encountered.
|
||||
"""
|
||||
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
|
||||
return tensor
|
||||
@@ -122,19 +253,13 @@ class _NormalizationMixin:
|
||||
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# Ensure input tensor is on the same device as the stats.
|
||||
if self.device and tensor.device != self.device:
|
||||
tensor = tensor.to(self.device)
|
||||
# For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor
|
||||
if self._tensor_stats and key in self._tensor_stats:
|
||||
first_stat = next(iter(self._tensor_stats[key].values()))
|
||||
if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
|
||||
self.to(device=tensor.device, dtype=tensor.dtype)
|
||||
|
||||
# For Accelerate compatibility: move stats to match input tensor device
|
||||
input_device = tensor.device
|
||||
stats = self._tensor_stats[key]
|
||||
tensor = tensor.to(dtype=torch.float32)
|
||||
|
||||
# Move stats to input device if needed
|
||||
stats_device = next(iter(stats.values())).device
|
||||
if stats_device != input_device:
|
||||
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
|
||||
|
||||
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
@@ -151,7 +276,7 @@ class _NormalizationMixin:
|
||||
# to prevent division by zero. This consistently maps an input equal to
|
||||
# min_val to -1, ensuring a stable transformation.
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom
|
||||
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
||||
)
|
||||
if inverse:
|
||||
# Map from [-1, 1] back to [min, max]
|
||||
@@ -165,13 +290,13 @@ class _NormalizationMixin:
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||
class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies normalization to observations and actions in a transition.
|
||||
A processor step that applies normalization to observations and actions in a transition.
|
||||
|
||||
This class directly implements the normalization logic for both observation and action
|
||||
components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
|
||||
initialization.
|
||||
This class uses the logic from `_NormalizationMixin` to perform forward normalization
|
||||
(e.g., scaling data to have zero mean and unit variance, or to the range [-1, 1]).
|
||||
It is typically used in the pre-processing pipeline before feeding data to a policy.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -184,7 +309,21 @@ class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
normalize_observation_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
device: torch.device | str | None = None,
|
||||
) -> NormalizerProcessor:
|
||||
) -> NormalizerProcessorStep:
|
||||
"""
|
||||
Creates a `NormalizerProcessorStep` instance using statistics from a `LeRobotDataset`.
|
||||
|
||||
Args:
|
||||
dataset: The dataset from which to extract normalization statistics.
|
||||
features: The feature definition for the processor.
|
||||
norm_map: The mapping from feature types to normalization modes.
|
||||
normalize_observation_keys: An optional set of observation keys to normalize.
|
||||
eps: A small epsilon value for numerical stability.
|
||||
device: The target device for the processor.
|
||||
|
||||
Returns:
|
||||
A new instance of `NormalizerProcessorStep`.
|
||||
"""
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
@@ -211,16 +350,22 @@ class NormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies unnormalization (the inverse of normalization) to
|
||||
observations and actions in a transition.
|
||||
A processor step that applies unnormalization to observations and actions.
|
||||
|
||||
This is typically used to transform actions from a normalized policy output back into
|
||||
the original scale for execution in an environment.
|
||||
This class inverts the normalization process, scaling data back to its original
|
||||
range. It is typically used in the post-processing pipeline to convert a policy's
|
||||
normalized action output into a format that can be executed by a robot or
|
||||
environment.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -231,7 +376,19 @@ class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
device: torch.device | str | None = None,
|
||||
) -> UnnormalizerProcessor:
|
||||
) -> UnnormalizerProcessorStep:
|
||||
"""
|
||||
Creates an `UnnormalizerProcessorStep` using statistics from a `LeRobotDataset`.
|
||||
|
||||
Args:
|
||||
dataset: The dataset from which to extract normalization statistics.
|
||||
features: The feature definition for the processor.
|
||||
norm_map: The mapping from feature types to normalization modes.
|
||||
device: The target device for the processor.
|
||||
|
||||
Returns:
|
||||
A new instance of `UnnormalizerProcessorStep`.
|
||||
"""
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -249,20 +406,35 @@ class UnnormalizerProcessor(_NormalizationMixin, ProcessorStep):
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
|
||||
"""
|
||||
Replaces normalization statistics in a RobotProcessor pipeline.
|
||||
|
||||
This function creates a deep copy of the provided `RobotProcessor` and updates the
|
||||
statistics of any `NormalizerProcessor` or `UnnormalizerProcessor` steps within it.
|
||||
It's useful for adapting a trained policy to a new environment or dataset with
|
||||
different data distributions.
|
||||
def hotswap_stats(
|
||||
policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]]
|
||||
) -> PolicyProcessorPipeline:
|
||||
"""
|
||||
rp = deepcopy(robot_processor)
|
||||
Replaces normalization statistics in an existing `PolicyProcessorPipeline` instance.
|
||||
|
||||
This function creates a deep copy of the provided pipeline and updates the
|
||||
statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` it
|
||||
contains. This is useful for adapting a trained policy to a new environment or
|
||||
dataset with different data distributions without having to reconstruct the entire
|
||||
pipeline.
|
||||
|
||||
Args:
|
||||
policy_processor: The policy processor pipeline to modify.
|
||||
stats: The new dictionary of normalization statistics to apply.
|
||||
|
||||
Returns:
|
||||
A new `PolicyProcessorPipeline` instance with the updated statistics.
|
||||
"""
|
||||
rp = deepcopy(policy_processor)
|
||||
for step in rp.steps:
|
||||
if isinstance(step, _NormalizationMixin):
|
||||
step.stats = stats
|
||||
# Re-initialize tensor_stats on the correct device.
|
||||
step._tensor_stats = to_tensor(stats, device=step.device)
|
||||
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype)
|
||||
return rp
|
||||
|
||||
@@ -20,32 +20,54 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class VanillaObservationProcessor(ObservationProcessor):
|
||||
class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes environment observations into the LeRobot format by handling both images and states.
|
||||
Processes standard Gymnasium observations into the LeRobot format.
|
||||
|
||||
Image processing:
|
||||
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
|
||||
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
|
||||
- Adds a batch dimension if missing
|
||||
- Supports single images and image dictionaries
|
||||
This step handles both image and state data from a typical observation dictionary,
|
||||
preparing it for use in a LeRobot policy.
|
||||
|
||||
State processing:
|
||||
- Maps 'environment_state' to observation.environment_state
|
||||
- Maps 'agent_pos' to observation.state
|
||||
- Converts numpy arrays to tensors
|
||||
- Adds a batch dimension if missing
|
||||
**Image Processing:**
|
||||
- Converts channel-last (H, W, C), `uint8` images to channel-first (C, H, W),
|
||||
`float32` tensors.
|
||||
- Normalizes pixel values from the [0, 255] range to [0, 1].
|
||||
- Adds a batch dimension if one is not already present.
|
||||
- Recognizes a single image under the key `"pixels"` and maps it to
|
||||
`"observation.image"`.
|
||||
- Recognizes a dictionary of images under the key `"pixels"` and maps them
|
||||
to `"observation.images.{camera_name}"`.
|
||||
|
||||
**State Processing:**
|
||||
- Maps the `"environment_state"` key to `"observation.environment_state"`.
|
||||
- Maps the `"agent_pos"` key to `"observation.state"`.
|
||||
- Converts NumPy arrays to PyTorch tensors.
|
||||
- Adds a batch dimension if one is not already present.
|
||||
"""
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||
"""Process a single image array."""
|
||||
"""
|
||||
Processes a single NumPy image array into a channel-first, normalized tensor.
|
||||
|
||||
Args:
|
||||
img: A NumPy array representing the image, expected to be in channel-last
|
||||
(H, W, C) format with a `uint8` dtype.
|
||||
|
||||
Returns:
|
||||
A `float32` PyTorch tensor in channel-first (B, C, H, W) format, with
|
||||
pixel values normalized to the [0, 1] range.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input image does not appear to be in channel-last
|
||||
format or is not of `uint8` dtype.
|
||||
"""
|
||||
# Convert to tensor
|
||||
img_tensor = torch.from_numpy(img)
|
||||
|
||||
@@ -106,18 +128,32 @@ class VanillaObservationProcessor(ObservationProcessor):
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms feature keys to a standardized contract.
|
||||
This method handles several renaming patterns:
|
||||
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
|
||||
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
|
||||
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- environment_state -> OBS_ENV_STATE,
|
||||
- agent_pos -> OBS_STATE,
|
||||
- observation.environment_state -> OBS_ENV_STATE,
|
||||
- observation.agent_pos -> OBS_STATE
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Transforms feature keys from the Gym standard to the LeRobot standard.
|
||||
|
||||
This method standardizes the feature dictionary by renaming keys according
|
||||
to LeRobot's conventions, ensuring that policies can be constructed correctly.
|
||||
It handles various raw key formats, including those with an "observation." prefix.
|
||||
|
||||
**Renaming Rules:**
|
||||
- `pixels` or `observation.pixels` -> `observation.image`
|
||||
- `pixels.{cam}` or `observation.pixels.{cam}` -> `observation.images.{cam}`
|
||||
- `environment_state` or `observation.environment_state` -> `observation.environment_state`
|
||||
- `agent_pos` or `observation.agent_pos` -> `observation.state`
|
||||
|
||||
Args:
|
||||
features: The policy features dictionary with Gym-style keys.
|
||||
|
||||
Returns:
|
||||
The policy features dictionary with standardized LeRobot keys.
|
||||
"""
|
||||
# Build a new features mapping keyed by the same FeatureType buckets
|
||||
# We assume callers already placed features in the correct FeatureType.
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()}
|
||||
|
||||
exact_pairs = {
|
||||
"pixels": OBS_IMAGE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
@@ -128,29 +164,43 @@ class VanillaObservationProcessor(ObservationProcessor):
|
||||
"pixels.": f"{OBS_IMAGES}.",
|
||||
}
|
||||
|
||||
for key in list(features.keys()):
|
||||
matched_prefix = False
|
||||
for old_prefix, new_prefix in prefix_pairs.items():
|
||||
prefixed_old = f"observation.{old_prefix}"
|
||||
if key.startswith(prefixed_old):
|
||||
suffix = key[len(prefixed_old) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
# Iterate over all incoming feature buckets and normalize/move each entry
|
||||
for src_ft, bucket in features.items():
|
||||
for key, feat in list(bucket.items()):
|
||||
handled = False
|
||||
|
||||
if key.startswith(old_prefix):
|
||||
suffix = key[len(old_prefix) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
|
||||
if matched_prefix:
|
||||
continue
|
||||
|
||||
for old, new in exact_pairs.items():
|
||||
if key == old or key == f"observation.{old}":
|
||||
if key in features:
|
||||
features[new] = features.pop(key)
|
||||
# Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1)
|
||||
for old_prefix, new_prefix in prefix_pairs.items():
|
||||
prefixed_old = f"observation.{old_prefix}"
|
||||
if key.startswith(prefixed_old):
|
||||
suffix = key[len(prefixed_old) :]
|
||||
new_key = f"{new_prefix}{suffix}"
|
||||
new_features[src_ft][new_key] = feat
|
||||
handled = True
|
||||
break
|
||||
|
||||
return features
|
||||
if key.startswith(old_prefix):
|
||||
suffix = key[len(old_prefix) :]
|
||||
new_key = f"{new_prefix}{suffix}"
|
||||
new_features[src_ft][new_key] = feat
|
||||
handled = True
|
||||
break
|
||||
|
||||
if handled:
|
||||
continue
|
||||
|
||||
# Exact-name rules (pixels, environment_state, agent_pos)
|
||||
for old, new in exact_pairs.items():
|
||||
if key == old or key == f"observation.{old}":
|
||||
new_key = new
|
||||
new_features[src_ft][new_key] = feat
|
||||
handled = True
|
||||
break
|
||||
|
||||
if handled:
|
||||
continue
|
||||
|
||||
# Default: keep key in the same source FeatureType bucket
|
||||
new_features[src_ft][key] = feat
|
||||
|
||||
return new_features
|
||||
|
||||
+132
-215
@@ -22,48 +22,22 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, TypedDict, TypeVar, cast
|
||||
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .converters import batch_to_transition, create_transition, transition_to_batch
|
||||
from .core import EnvTransition, TransitionKey
|
||||
|
||||
# Type variable for generic processor output type
|
||||
TOutput = TypeVar("TOutput")
|
||||
|
||||
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
|
||||
# TODO(Steven): Use consts
|
||||
OBSERVATION = "observation"
|
||||
ACTION = "action"
|
||||
REWARD = "reward"
|
||||
DONE = "done"
|
||||
TRUNCATED = "truncated"
|
||||
INFO = "info"
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
EnvTransition = TypedDict(
|
||||
"EnvTransition",
|
||||
{
|
||||
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
||||
TransitionKey.ACTION.value: Any | torch.Tensor | None,
|
||||
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||||
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||||
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
||||
TransitionKey.INFO.value: dict[str, Any] | None,
|
||||
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class ProcessorStepRegistry:
|
||||
"""Registry for processor steps that enables saving/loading by name instead of module path."""
|
||||
|
||||
@@ -142,7 +116,7 @@ class ProcessorStep(ABC):
|
||||
A step is any callable accepting a full `EnvTransition` dict and
|
||||
returning a (possibly modified) dict of the same structure. Implementers
|
||||
are encouraged—but not required—to expose the optional helper methods
|
||||
listed below. When present, these hooks let `RobotProcessor`
|
||||
listed below. When present, these hooks let `DataProcessorPipeline`
|
||||
automatically serialise the step's configuration and learnable state using
|
||||
a safe-to-share JSON + SafeTensors format.
|
||||
|
||||
@@ -194,107 +168,22 @@ class ProcessorStep(ABC):
|
||||
def reset(self) -> None:
|
||||
return None
|
||||
|
||||
# TODO(Steven): Consider making this abstract so it is more explicit
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
@abstractmethod
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401
|
||||
"""Convert a *batch* dict coming from Learobot replay/dataset code into an
|
||||
``EnvTransition`` dictionary.
|
||||
|
||||
The function maps well known keys to the EnvTransition structure. Missing keys are
|
||||
filled with sane defaults (``None`` or ``0.0``/``False``).
|
||||
|
||||
Keys recognised (case-sensitive):
|
||||
|
||||
* "observation.*" (keys starting with "observation." are grouped into observation dict)
|
||||
* "action"
|
||||
* "next.reward"
|
||||
* "next.done"
|
||||
* "next.truncated"
|
||||
* "info"
|
||||
|
||||
Additional keys are ignored so that existing dataloaders can carry extra
|
||||
metadata without breaking the processor.
|
||||
"""
|
||||
|
||||
# Validate input type
|
||||
if not isinstance(batch, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||
|
||||
# Extract observation keys
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
observation = observation_keys if observation_keys else None
|
||||
|
||||
# Extract padding, task, index, and task_index keys for complementary data
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
complementary_data = (
|
||||
{**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
if pad_keys or task_key or index_key or task_index_key
|
||||
else {}
|
||||
)
|
||||
|
||||
transition: EnvTransition = {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: batch.get("action"),
|
||||
TransitionKey.REWARD: batch.get("next.reward", 0.0),
|
||||
TransitionKey.DONE: batch.get("next.done", False),
|
||||
TransitionKey.TRUNCATED: batch.get("next.truncated", False),
|
||||
TransitionKey.INFO: batch.get("info", {}),
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
return transition
|
||||
|
||||
|
||||
def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401
|
||||
"""Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with
|
||||
the canonical field names used throughout *LeRobot*.
|
||||
"""
|
||||
|
||||
batch = {
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Add padding, task, index, and task_index data from complementary_data
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data:
|
||||
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
|
||||
batch.update(pad_data)
|
||||
|
||||
if "task" in complementary_data:
|
||||
batch["task"] = complementary_data["task"]
|
||||
|
||||
if "index" in complementary_data:
|
||||
batch["index"] = complementary_data["index"]
|
||||
|
||||
if "task_index" in complementary_data:
|
||||
batch["task_index"] = complementary_data["task_index"]
|
||||
|
||||
# Handle observation - flatten dict to observation.* keys if it's a dict
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if isinstance(observation, dict):
|
||||
batch.update(observation)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class ProcessorKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for RobotProcessor constructor."""
|
||||
"""Keyword arguments for DataProcessorPipeline constructor."""
|
||||
|
||||
to_transition: Callable[[dict[str, Any]], EnvTransition] | None
|
||||
to_output: Callable[[EnvTransition], Any] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
class DataProcessorPipeline(ModelHubMixin, Generic[TOutput]):
|
||||
"""
|
||||
Composable, debuggable post-processing processor for robot transitions.
|
||||
|
||||
@@ -308,7 +197,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
Args:
|
||||
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
|
||||
name: Human-readable identifier that is persisted inside the JSON config.
|
||||
Defaults to "RobotProcessor".
|
||||
Defaults to "DataProcessorPipeline".
|
||||
to_transition: Function to convert batch dict to EnvTransition dict.
|
||||
Defaults to _default_batch_to_transition.
|
||||
to_output: Function to convert EnvTransition dict to the desired output format of type TOutput.
|
||||
@@ -322,18 +211,20 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
Type Safety Examples:
|
||||
```python
|
||||
# Default behavior - returns batch dict
|
||||
processor: RobotProcessor[dict[str, Any]] = RobotProcessor(steps=[some_step1, some_step2])
|
||||
processor: DataProcessorPipeline[dict[str, Any]] = DataProcessorPipeline(
|
||||
steps=[some_step1, some_step2]
|
||||
)
|
||||
result: dict[str, Any] = processor(batch_data) # Type checker knows this is a dict
|
||||
|
||||
# For EnvTransition output, explicitly specify identity function
|
||||
transition_processor: RobotProcessor[EnvTransition] = RobotProcessor(
|
||||
transition_processor: DataProcessorPipeline[EnvTransition] = DataProcessorPipeline(
|
||||
steps=[some_step1, some_step2],
|
||||
to_output=lambda x: x, # Identity function
|
||||
)
|
||||
result: EnvTransition = transition_processor(batch_data) # Type checker knows this is EnvTransition
|
||||
|
||||
# For custom output types
|
||||
processor: RobotProcessor[str] = RobotProcessor(
|
||||
processor: DataProcessorPipeline[str] = DataProcessorPipeline(
|
||||
steps=[custom_step], to_output=lambda t: f"Processed {len(t)} keys"
|
||||
)
|
||||
result: str = processor(batch_data) # Type checker knows this is str
|
||||
@@ -355,17 +246,15 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
"""
|
||||
|
||||
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
||||
name: str = "RobotProcessor"
|
||||
name: str = "DataProcessorPipeline"
|
||||
|
||||
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
|
||||
default_factory=lambda: _default_batch_to_transition, repr=False
|
||||
)
|
||||
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(default=batch_to_transition, repr=False)
|
||||
to_output: Callable[[EnvTransition], TOutput] = field(
|
||||
# Cast is necessary here: Working around Python type-checker limitation.
|
||||
# _default_transition_to_batch returns dict[str, Any], but we need it to be TOutput
|
||||
# for the generic to work. When no explicit type is given, TOutput defaults to dict[str, Any],
|
||||
# making this cast safe.
|
||||
default_factory=lambda: cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
|
||||
default_factory=lambda: cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
repr=False,
|
||||
)
|
||||
|
||||
@@ -390,6 +279,12 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
# Always convert input through to_transition
|
||||
transition = self.to_transition(data)
|
||||
|
||||
transformed_transition = self._forward(transition)
|
||||
|
||||
# Always use to_output for consistent typing
|
||||
return self.to_output(transformed_transition)
|
||||
|
||||
def _forward(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Process through all steps
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
# Apply before hooks
|
||||
@@ -402,9 +297,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
# Apply after hooks
|
||||
for hook in self.after_step_hooks:
|
||||
hook(idx, transition)
|
||||
|
||||
# Always use to_output for consistent typing
|
||||
return self.to_output(transition)
|
||||
return transition
|
||||
|
||||
def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
|
||||
"""Yield the intermediate results after each processor step.
|
||||
@@ -529,7 +422,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
to_transition: Callable[[dict[str, Any]], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
**kwargs,
|
||||
) -> RobotProcessor[TOutput]:
|
||||
) -> DataProcessorPipeline[TOutput]:
|
||||
"""Load a serialized processor from source (local path or Hugging Face Hub identifier).
|
||||
|
||||
Args:
|
||||
@@ -537,7 +430,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
(e.g., "username/processor-name").
|
||||
config_filename: Optional specific config filename to load. If not provided, will:
|
||||
- For local paths: look for any .json file in the directory (error if multiple found)
|
||||
- For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json")
|
||||
- For HF Hub: REQUIRED - you must specify the exact config filename
|
||||
overrides: Optional dictionary mapping step names to configuration overrides.
|
||||
Keys must match exact step class names (for unregistered steps) or registry names
|
||||
(for registered steps). Values are dictionaries containing parameter overrides
|
||||
@@ -550,7 +443,7 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
Use identity function (lambda x: x) for EnvTransition output.
|
||||
|
||||
Returns:
|
||||
A RobotProcessor[TOutput] instance loaded from the saved configuration.
|
||||
A DataProcessorPipeline[TOutput] instance loaded from the saved configuration.
|
||||
|
||||
Raises:
|
||||
ImportError: If a processor step class cannot be loaded or imported.
|
||||
@@ -560,13 +453,13 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
Examples:
|
||||
Basic loading:
|
||||
```python
|
||||
processor = RobotProcessor.from_pretrained("path/to/processor")
|
||||
processor = DataProcessorPipeline.from_pretrained("path/to/processor")
|
||||
```
|
||||
|
||||
Loading specific config file:
|
||||
Loading from HF Hub (config_filename required):
|
||||
```python
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
"username/multi-processor-repo", config_filename="preprocessor.json"
|
||||
processor = DataProcessorPipeline.from_pretrained(
|
||||
"username/processor-repo", config_filename="processor.json"
|
||||
)
|
||||
```
|
||||
|
||||
@@ -575,14 +468,14 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
import gym
|
||||
|
||||
env = gym.make("CartPole-v1")
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
processor = DataProcessorPipeline.from_pretrained(
|
||||
"username/cartpole-processor", overrides={"ActionRepeatStep": {"env": env}}
|
||||
)
|
||||
```
|
||||
|
||||
Multiple overrides:
|
||||
```python
|
||||
processor = RobotProcessor.from_pretrained(
|
||||
processor = DataProcessorPipeline.from_pretrained(
|
||||
"path/to/processor",
|
||||
overrides={
|
||||
"CustomStep": {"param1": "new_value"},
|
||||
@@ -594,7 +487,19 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
# Use the local variable name 'source' for clarity
|
||||
source = str(pretrained_model_name_or_path)
|
||||
|
||||
if Path(source).is_dir():
|
||||
# Check if it's a local path (either exists or looks like a filesystem path)
|
||||
# Hub repositories are typically in the format "username/repo-name" (exactly one slash)
|
||||
# Local paths are absolute paths, relative paths, or have more complex path structure
|
||||
is_local_path = (
|
||||
Path(source).is_dir()
|
||||
or Path(source).is_absolute()
|
||||
or source.startswith("./")
|
||||
or source.startswith("../")
|
||||
or source.count("/") > 1 # More than one slash suggests local path, not Hub repo
|
||||
or "\\" in source # Windows-style paths are definitely local
|
||||
)
|
||||
|
||||
if is_local_path:
|
||||
# Local path - use it directly
|
||||
base_path = Path(source)
|
||||
|
||||
@@ -613,57 +518,26 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
with open(base_path / config_filename) as file_pointer:
|
||||
loaded_config: dict[str, Any] = json.load(file_pointer)
|
||||
else:
|
||||
# Hugging Face Hub - download all required files
|
||||
# Hugging Face Hub - download specific config file
|
||||
if config_filename is None:
|
||||
# Try common config names
|
||||
common_names = [
|
||||
"robot_processor.json",
|
||||
"robot_preprocessor.json",
|
||||
"robot_postprocessor.json",
|
||||
]
|
||||
config_path = None
|
||||
for name in common_names:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
source,
|
||||
name,
|
||||
repo_type="model",
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
config_filename = name
|
||||
break
|
||||
except (FileNotFoundError, OSError, HfHubHTTPError):
|
||||
# FileNotFoundError: local file issues
|
||||
# OSError: network/system errors
|
||||
# HfHubHTTPError: file not found on Hub (404) or other HTTP errors
|
||||
continue
|
||||
|
||||
if config_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"No processor configuration file found in {source}. "
|
||||
f"Tried: {common_names}. Please specify the config_filename parameter."
|
||||
)
|
||||
else:
|
||||
# Download specific config file
|
||||
config_path = hf_hub_download(
|
||||
source,
|
||||
config_filename,
|
||||
repo_type="model",
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
raise ValueError(
|
||||
f"For Hugging Face Hub repositories ({source}), you must specify the config_filename parameter. "
|
||||
f"Example: DataProcessorPipeline.from_pretrained('{source}', config_filename='processor.json')"
|
||||
)
|
||||
|
||||
config_path = hf_hub_download(
|
||||
source,
|
||||
config_filename,
|
||||
repo_type="model",
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
with open(config_path) as file_pointer:
|
||||
loaded_config = json.load(file_pointer)
|
||||
|
||||
@@ -766,25 +640,25 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
|
||||
return cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "RobotProcessor"),
|
||||
to_transition=to_transition or _default_batch_to_transition,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or batch_to_transition,
|
||||
# Cast is necessary here: Same type-checker limitation as above.
|
||||
# When to_output is None, we use the default which returns dict[str, Any].
|
||||
# The cast ensures type consistency with the generic TOutput parameter.
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of steps in the processor."""
|
||||
return len(self.steps)
|
||||
|
||||
def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor[TOutput]:
|
||||
def __getitem__(self, idx: int | slice) -> ProcessorStep | DataProcessorPipeline[TOutput]:
|
||||
"""Indexing helper exposing underlying steps.
|
||||
* ``int`` – returns the idx-th ProcessorStep.
|
||||
* ``slice`` – returns a new RobotProcessor with the sliced steps.
|
||||
* ``slice`` – returns a new DataProcessorPipeline with the sliced steps.
|
||||
"""
|
||||
if isinstance(idx, slice):
|
||||
return RobotProcessor(
|
||||
return DataProcessorPipeline(
|
||||
steps=self.steps[idx],
|
||||
name=self.name,
|
||||
to_transition=self.to_transition,
|
||||
@@ -855,30 +729,68 @@ class RobotProcessor(ModelHubMixin, Generic[TOutput]):
|
||||
|
||||
parts = [f"name='{self.name}'", steps_repr]
|
||||
|
||||
return f"RobotProcessor({', '.join(parts)})"
|
||||
return f"DataProcessorPipeline({', '.join(parts)})"
|
||||
|
||||
def __post_init__(self):
|
||||
for i, step in enumerate(self.steps):
|
||||
if not callable(step):
|
||||
# TODO(steven): This should instead check isinstance(step, ProcessorStep), test need to be updated
|
||||
raise TypeError(
|
||||
f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
|
||||
)
|
||||
if not isinstance(step, ProcessorStep):
|
||||
raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep")
|
||||
|
||||
def transform_features(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, initial_features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Apply ALL steps in order. Only if a step has a features method, it will be called.
|
||||
We aggregate the dataset features of all steps.
|
||||
"""
|
||||
features: dict[str, PolicyFeature] = deepcopy(initial_features)
|
||||
features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = deepcopy(initial_features)
|
||||
|
||||
for _, step in enumerate(self.steps):
|
||||
out = step.transform_features(features)
|
||||
features = out
|
||||
return features
|
||||
|
||||
def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
transition: EnvTransition = create_transition(observation=observation)
|
||||
transformed_transition = self._forward(transition)
|
||||
return transformed_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
class ObservationProcessor(ProcessorStep, ABC):
|
||||
def process_action(self, action: Any | torch.Tensor) -> Any | torch.Tensor:
|
||||
transition: EnvTransition = create_transition(action=action)
|
||||
transformed_transition = self._forward(transition)
|
||||
return transformed_transition[TransitionKey.ACTION]
|
||||
|
||||
def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor:
|
||||
transition: EnvTransition = create_transition(reward=reward)
|
||||
transformed_transition = self._forward(transition)
|
||||
return transformed_transition[TransitionKey.REWARD]
|
||||
|
||||
def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor:
|
||||
transition: EnvTransition = create_transition(done=done)
|
||||
transformed_transition = self._forward(transition)
|
||||
return transformed_transition[TransitionKey.DONE]
|
||||
|
||||
def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor:
|
||||
transition: EnvTransition = create_transition(truncated=truncated)
|
||||
transformed_transition = self._forward(transition)
|
||||
return transformed_transition[TransitionKey.TRUNCATED]
|
||||
|
||||
def process_info(self, info: dict[str, Any]) -> dict[str, Any]:
|
||||
transition: EnvTransition = create_transition(info=info)
|
||||
transformed_transition = self._forward(transition)
|
||||
return transformed_transition[TransitionKey.INFO]
|
||||
|
||||
def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]:
|
||||
transition: EnvTransition = create_transition(complementary_data=complementary_data)
|
||||
transformed_transition = self._forward(transition)
|
||||
return transformed_transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
|
||||
RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TOutput]
|
||||
PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TOutput]
|
||||
|
||||
|
||||
class ObservationProcessorStep(ProcessorStep, ABC):
|
||||
"""Base class for processors that modify only the observation component of a transition.
|
||||
|
||||
Subclasses should override the `observation` method to implement custom observation processing.
|
||||
@@ -924,7 +836,7 @@ class ObservationProcessor(ProcessorStep, ABC):
|
||||
return new_transition
|
||||
|
||||
|
||||
class ActionProcessor(ProcessorStep, ABC):
|
||||
class ActionProcessorStep(ProcessorStep, ABC):
|
||||
"""Base class for processors that modify only the action component of a transition.
|
||||
|
||||
Subclasses should override the `action` method to implement custom action processing.
|
||||
@@ -971,7 +883,7 @@ class ActionProcessor(ProcessorStep, ABC):
|
||||
return new_transition
|
||||
|
||||
|
||||
class RewardProcessor(ProcessorStep, ABC):
|
||||
class RewardProcessorStep(ProcessorStep, ABC):
|
||||
"""Base class for processors that modify only the reward component of a transition.
|
||||
|
||||
Subclasses should override the `reward` method to implement custom reward processing.
|
||||
@@ -1017,7 +929,7 @@ class RewardProcessor(ProcessorStep, ABC):
|
||||
return new_transition
|
||||
|
||||
|
||||
class DoneProcessor(ProcessorStep, ABC):
|
||||
class DoneProcessorStep(ProcessorStep, ABC):
|
||||
"""Base class for processors that modify only the done flag of a transition.
|
||||
|
||||
Subclasses should override the `done` method to implement custom done flag processing.
|
||||
@@ -1068,7 +980,7 @@ class DoneProcessor(ProcessorStep, ABC):
|
||||
return new_transition
|
||||
|
||||
|
||||
class TruncatedProcessor(ProcessorStep, ABC):
|
||||
class TruncatedProcessorStep(ProcessorStep, ABC):
|
||||
"""Base class for processors that modify only the truncated flag of a transition.
|
||||
|
||||
Subclasses should override the `truncated` method to implement custom truncated flag processing.
|
||||
@@ -1115,7 +1027,7 @@ class TruncatedProcessor(ProcessorStep, ABC):
|
||||
return new_transition
|
||||
|
||||
|
||||
class InfoProcessor(ProcessorStep, ABC):
|
||||
class InfoProcessorStep(ProcessorStep, ABC):
|
||||
"""Base class for processors that modify only the info dictionary of a transition.
|
||||
|
||||
Subclasses should override the `info` method to implement custom info processing.
|
||||
@@ -1167,7 +1079,7 @@ class InfoProcessor(ProcessorStep, ABC):
|
||||
return new_transition
|
||||
|
||||
|
||||
class ComplementaryDataProcessor(ProcessorStep, ABC):
|
||||
class ComplementaryDataProcessorStep(ProcessorStep, ABC):
|
||||
"""Base class for processors that modify only the complementary data of a transition.
|
||||
|
||||
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
|
||||
@@ -1200,8 +1112,13 @@ class ComplementaryDataProcessor(ProcessorStep, ABC):
|
||||
return new_transition
|
||||
|
||||
|
||||
class IdentityProcessor(ProcessorStep):
|
||||
class IdentityProcessorStep(ProcessorStep):
|
||||
"""Identity processor that does nothing."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -17,17 +17,26 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="rename_processor")
|
||||
class RenameProcessor(ObservationProcessor):
|
||||
"""Rename processor that renames keys in the observation."""
|
||||
@ProcessorStepRegistry.register(name="rename_observations_processor")
|
||||
class RenameObservationsProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
A processor step that renames keys in an observation dictionary.
|
||||
|
||||
This step is useful for creating a standardized data interface by mapping keys
|
||||
from an environment's format to the format expected by a LeRobot policy or
|
||||
other downstream components.
|
||||
|
||||
Attributes:
|
||||
rename_map: A dictionary mapping from old key names to new key names.
|
||||
Keys present in an observation that are not in this map will
|
||||
be kept with their original names.
|
||||
"""
|
||||
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@@ -44,16 +53,37 @@ class RenameProcessor(ObservationProcessor):
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"rename_map": self.rename_map}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Transforms:
|
||||
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
||||
- Keys not in `rename_map` remain unchanged.
|
||||
"""
|
||||
return {self.rename_map.get(k, k): v for k, v in features.items()}
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = features.copy()
|
||||
new_features[PipelineFeatureType.OBSERVATION] = {
|
||||
self.rename_map.get(k, k): v for k, v in features[PipelineFeatureType.OBSERVATION].items()
|
||||
}
|
||||
return new_features
|
||||
|
||||
|
||||
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
|
||||
"""Rename keys in the stats dictionary according to rename_map (defensive copy)."""
|
||||
"""
|
||||
Renames the top-level keys in a statistics dictionary using a provided mapping.
|
||||
|
||||
This is a helper function typically used to keep normalization statistics
|
||||
consistent with renamed observation or action features. It performs a defensive
|
||||
deep copy to avoid modifying the original `stats` dictionary.
|
||||
|
||||
Args:
|
||||
stats: A nested dictionary of statistics, where top-level keys are
|
||||
feature names (e.g., `{"observation.state": {"mean": 0.5}}`).
|
||||
rename_map: A dictionary mapping old feature names to new feature names.
|
||||
|
||||
Returns:
|
||||
A new statistics dictionary with its top-level keys renamed. Returns an
|
||||
empty dictionary if the input `stats` is empty.
|
||||
"""
|
||||
if not stats:
|
||||
return {}
|
||||
renamed: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@@ -1,5 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tokenizer processor for handling text tokenization in robot transitions.
|
||||
This script defines a processor for tokenizing natural language instructions from an environment transition.
|
||||
|
||||
It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
|
||||
token IDs and attention masks, which are then added to the observation dictionary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -9,16 +28,14 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.processor.pipeline import (
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
else:
|
||||
@@ -27,68 +44,62 @@ else:
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="tokenizer_processor")
|
||||
class TokenizerProcessor(ObservationProcessor):
|
||||
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
|
||||
class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processor step to tokenize a natural language task description.
|
||||
|
||||
This processor handles tokenization of task strings found in the complementary_data
|
||||
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
|
||||
to the observation data for model processing while preserving the original task string.
|
||||
This step extracts a task string from the `complementary_data` of an `EnvTransition`,
|
||||
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
|
||||
token IDs and attention mask to the `observation` dictionary.
|
||||
|
||||
The processor supports both single strings and lists of strings as task inputs.
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Args:
|
||||
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
|
||||
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
|
||||
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
|
||||
tokenizer: A tokenizer object (e.g., from transformers library) that implements
|
||||
the __call__ method. If provided, tokenizer_name is ignored. This parameter
|
||||
is not serialized and must be provided via overrides when loading.
|
||||
max_length: Maximum sequence length for tokenization. Defaults to 512.
|
||||
task_key: Key in complementary_data containing the task text. Defaults to "task".
|
||||
padding: Padding strategy for tokenization. Defaults to "max_length".
|
||||
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
|
||||
|
||||
Examples:
|
||||
Using tokenizer name (auto-loaded):
|
||||
```python
|
||||
processor = TokenizerProcessor(tokenizer_name="bert-base-uncased", max_length=128)
|
||||
```
|
||||
|
||||
Using custom tokenizer object:
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
processor = TokenizerProcessor(tokenizer=custom_tokenizer, max_length=128)
|
||||
```
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained tokenizer from the Hugging Face Hub (e.g., "bert-base-uncased").
|
||||
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
max_length: The maximum length to pad or truncate sequences to.
|
||||
task_key: The key in `complementary_data` where the task string is stored.
|
||||
padding_side: The side to pad on ('left' or 'right').
|
||||
padding: The padding strategy ('max_length', 'longest', etc.).
|
||||
truncation: Whether to truncate sequences longer than `max_length`.
|
||||
input_tokenizer: The internal tokenizer instance, loaded during initialization.
|
||||
"""
|
||||
|
||||
tokenizer_name: str | None = None
|
||||
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
|
||||
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
|
||||
max_length: int = 512
|
||||
task_key: str = "task"
|
||||
padding_side: str = "right"
|
||||
padding: str = "max_length"
|
||||
truncation: bool = True
|
||||
|
||||
# Internal tokenizer instance (not serialized)
|
||||
_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
# Internal tokenizer instance (not part of the config)
|
||||
input_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
|
||||
"""
|
||||
Initializes the tokenizer after the dataclass is created.
|
||||
|
||||
It checks for the availability of the `transformers` library and loads the tokenizer
|
||||
either from a provided object or by name from the Hugging Face Hub.
|
||||
|
||||
Raises:
|
||||
ImportError: If the `transformers` library is not installed.
|
||||
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
|
||||
"""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessor."
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessorStep."
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
# Use provided tokenizer object directly
|
||||
self._tokenizer = self.tokenizer
|
||||
self.input_tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
if AutoTokenizer is None:
|
||||
raise ImportError("AutoTokenizer is not available")
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
|
||||
@@ -96,13 +107,14 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
)
|
||||
|
||||
def get_task(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""Extract and normalize task from complementary data.
|
||||
"""
|
||||
Extracts the task description(s) from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: Input transition containing complementary_data.
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
List of task strings if task is present, None otherwise.
|
||||
A list of task strings, or None if the task key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
@@ -115,7 +127,7 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
if task is None:
|
||||
return None
|
||||
|
||||
# Convert to list of strings
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(task, str):
|
||||
return [task]
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
@@ -123,80 +135,82 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation):
|
||||
"""Process the transition by tokenizing the task text.
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
|
||||
This method retrieves the task, tokenizes it, moves the resulting tensors to the
|
||||
same device as other data in the transition, and updates the observation.
|
||||
|
||||
Args:
|
||||
transition: Input transition containing complementary_data with task text.
|
||||
observation: The original observation dictionary.
|
||||
|
||||
Returns:
|
||||
Modified transition with tokenized task added to observation.
|
||||
|
||||
Raises:
|
||||
ValueError: If tokenizer initialization failed.
|
||||
The updated observation dictionary including token IDs and an attention mask.
|
||||
"""
|
||||
task = self.get_task(self.transition)
|
||||
if task is None:
|
||||
return observation
|
||||
|
||||
# Tokenize the task (creates CPU tensors)
|
||||
# Tokenize the task (this will create CPU tensors)
|
||||
tokenized_prompt = self._tokenize_text(task)
|
||||
|
||||
# Detect device from existing tensors in the transition
|
||||
# Detect the device from existing tensors in the transition to ensure consistency
|
||||
target_device = self._detect_device(self.transition)
|
||||
|
||||
# Move tokenized tensors to match the device of other data
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_prompt.items()
|
||||
}
|
||||
|
||||
# Get or create observation dict
|
||||
# Create a new observation dict to avoid modifying the original in place
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Add tokenized data to observation
|
||||
# Add tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
"""Detect device from existing tensors in the transition.
|
||||
"""
|
||||
Detects the torch.device from existing tensors in the transition.
|
||||
|
||||
This allows the tokenized tensors to match the device of other data,
|
||||
which is especially important for multi-GPU training with Accelerate.
|
||||
It checks tensors in the observation dictionary first, then the action tensor.
|
||||
|
||||
Args:
|
||||
transition: The transition to search for existing tensors.
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
The device of the first tensor found, or None if no tensors exist.
|
||||
The detected `torch.device`, or None if no tensors are found.
|
||||
"""
|
||||
# Check observation tensors first (most likely to exist)
|
||||
# Check observation tensors first (most likely place to find tensors)
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation:
|
||||
for value in observation.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
|
||||
# Check action tensor
|
||||
# Fallback to checking the action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if isinstance(action, torch.Tensor):
|
||||
return action.device
|
||||
|
||||
return None # No tensors found, keep on CPU
|
||||
return None # No tensors found, default will be CPU
|
||||
|
||||
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
||||
"""Tokenize text using the configured tokenizer.
|
||||
"""
|
||||
A wrapper around the tokenizer call.
|
||||
|
||||
Args:
|
||||
text: Text string or list of strings to tokenize.
|
||||
text: A string or list of strings to tokenize.
|
||||
|
||||
Returns:
|
||||
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
|
||||
A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors.
|
||||
"""
|
||||
return self._tokenizer(
|
||||
return self.input_tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
truncation=self.truncation,
|
||||
@@ -206,10 +220,14 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization.
|
||||
"""
|
||||
Returns the serializable configuration of the processor.
|
||||
|
||||
Note: Only tokenizer_name is saved, not the tokenizer object itself.
|
||||
When loading, provide the tokenizer via overrides if needed.
|
||||
Note: The tokenizer object itself is not serialized. If the processor was initialized
|
||||
with a tokenizer name, that name will be included in the config.
|
||||
|
||||
Returns:
|
||||
A dictionary with the processor's configuration parameters.
|
||||
"""
|
||||
config = {
|
||||
"max_length": self.max_length,
|
||||
@@ -219,30 +237,36 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
"truncation": self.truncation,
|
||||
}
|
||||
|
||||
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
|
||||
# TODO(steven): Consider saving the name of the _tokenizer if it was loaded
|
||||
# Only save tokenizer_name if it was used to create the tokenizer
|
||||
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||
config["tokenizer_name"] = self.tokenizer_name
|
||||
|
||||
return config
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Add tokenized task features to the feature contract.
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Adds feature definitions for the language tokens and attention mask.
|
||||
|
||||
This updates the policy features dictionary to include the new data added to the
|
||||
observation, ensuring downstream components are aware of their shape and type.
|
||||
|
||||
Args:
|
||||
features: Input feature dictionary.
|
||||
features: The dictionary of existing policy features.
|
||||
|
||||
Returns:
|
||||
Updated feature dictionary with tokenized task features added.
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
# Add features for tokenized output if they don't exist
|
||||
# Standard tokenizer output includes tokens and attention_mask
|
||||
# Add a feature for the token IDs if it doesn't already exist
|
||||
if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_TOKENS not in features:
|
||||
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
|
||||
if OBS_LANGUAGE_ATTENTION_MASK not in features:
|
||||
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||
# Add a feature for the attention mask if it doesn't already exist
|
||||
if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
|
||||
+62
-24
@@ -62,6 +62,7 @@ import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras import ( # noqa: F401
|
||||
CameraConfig, # noqa: F401
|
||||
@@ -76,14 +77,20 @@ from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.converters import (
|
||||
to_dataset_frame,
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
from lerobot.processor import (
|
||||
EnvTransition,
|
||||
IdentityProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RobotProcessorPipeline,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
action_to_transition,
|
||||
identity_transition,
|
||||
observation_to_transition,
|
||||
transition_to_action,
|
||||
transition_to_dataset_frame,
|
||||
)
|
||||
from lerobot.processor.pipeline import IdentityProcessor, TransitionKey
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
@@ -236,23 +243,36 @@ def record_loop(
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
preprocessor: RobotProcessor | None = None,
|
||||
postprocessor: RobotProcessor | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
control_time_s: int | None = None,
|
||||
teleop_action_processor: RobotProcessor | None = None, # runs after teleop
|
||||
robot_action_processor: RobotProcessor | None = None, # runs before robot
|
||||
robot_observation_processor: RobotProcessor | None = None, # runs after robot
|
||||
teleop_action_processor: RobotProcessorPipeline[EnvTransition] | None = None, # runs after teleop
|
||||
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] | None = None, # runs before robot
|
||||
robot_observation_processor: RobotProcessorPipeline[EnvTransition] | None = None, # runs after robot
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
):
|
||||
teleop_action_processor = teleop_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
|
||||
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
|
||||
teleop_action_processor
|
||||
or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition
|
||||
)
|
||||
)
|
||||
robot_action_processor = robot_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=lambda tr: tr, to_output=to_output_robot_action
|
||||
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
|
||||
robot_action_processor
|
||||
or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=identity_transition,
|
||||
to_output=transition_to_action,
|
||||
)
|
||||
)
|
||||
robot_observation_processor = robot_observation_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
|
||||
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
|
||||
robot_observation_processor
|
||||
or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
)
|
||||
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
@@ -265,7 +285,14 @@ def record_loop(
|
||||
(
|
||||
t
|
||||
for t in teleop
|
||||
if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
|
||||
if isinstance(
|
||||
t,
|
||||
(
|
||||
so100_leader.SO100Leader,
|
||||
so101_leader.SO101Leader,
|
||||
koch_leader.KochLeader,
|
||||
),
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -308,7 +335,7 @@ def record_loop(
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
if dataset is not None:
|
||||
observation_frame = to_dataset_frame(
|
||||
observation_frame = transition_to_dataset_frame(
|
||||
obs_transition, dataset.features
|
||||
) # Convert the observation to the dataset format
|
||||
|
||||
@@ -334,6 +361,7 @@ def record_loop(
|
||||
act = teleop.get_action()
|
||||
|
||||
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
||||
# TODO(Steven): This assumes that the processor passed by the user should have identity_transition as to_output.
|
||||
teleop_transition = teleop_action_processor(act)
|
||||
|
||||
elif isinstance(teleop, list):
|
||||
@@ -366,7 +394,7 @@ def record_loop(
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
# If to_dataset_frame is provided, use it to merge the transitions.
|
||||
# If transition_to_dataset_frame is provided, use it to merge the transitions.
|
||||
merged = []
|
||||
if obs_transition is not None: # The observation from the robot
|
||||
merged.append(obs_transition)
|
||||
@@ -374,13 +402,15 @@ def record_loop(
|
||||
merged.append(teleop_transition)
|
||||
if policy_transition is not None: # The action from policy
|
||||
merged.append(policy_transition)
|
||||
frame = to_dataset_frame(
|
||||
frame = transition_to_dataset_frame(
|
||||
merged if len(merged) > 1 else merged[0], dataset.features
|
||||
) # Convert the observation to the dataset format
|
||||
dataset.add_frame(frame, task=single_task)
|
||||
|
||||
if display_data:
|
||||
log_rerun_data([obs_transition, teleop_transition or policy_transition])
|
||||
log_rerun_data(
|
||||
observation=obs_transition.get(TransitionKey.OBSERVATION), action=robot_action_to_send
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
@@ -400,7 +430,15 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Add next.* features that are generated during recording
|
||||
transition_features = {
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
"next.truncated": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
dataset_features = {**action_features, **obs_features, **transition_features}
|
||||
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
|
||||
@@ -47,9 +47,8 @@ from pprint import pformat
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.converters import to_output_robot_action, to_transition_teleop_action
|
||||
from lerobot.processor.pipeline import IdentityProcessor
|
||||
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import action_to_transition, transition_to_action
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -57,6 +56,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
make_robot_from_config,
|
||||
reachy2,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
@@ -86,7 +86,7 @@ class ReplayConfig:
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Optional processor for actions before sending to robot
|
||||
robot_action_processor: RobotProcessor | None = None
|
||||
robot_action_processor: RobotProcessorPipeline | None = None
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
@@ -95,10 +95,10 @@ def replay(cfg: ReplayConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Initialize robot action processor with default if not provided
|
||||
robot_action_processor = cfg.robot_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=to_transition_teleop_action,
|
||||
to_output=to_output_robot_action, # type: ignore[arg-type]
|
||||
robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=action_to_transition,
|
||||
to_output=transition_to_action, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Reset processor
|
||||
|
||||
@@ -29,10 +29,10 @@ class BiSO100FollowerConfig(RobotConfig):
|
||||
|
||||
# Optional
|
||||
left_arm_disable_torque_on_disconnect: bool = True
|
||||
left_arm_max_relative_target: int | None = None
|
||||
left_arm_max_relative_target: float | dict[str, float] | None = None
|
||||
left_arm_use_degrees: bool = False
|
||||
right_arm_disable_torque_on_disconnect: bool = True
|
||||
right_arm_max_relative_target: int | None = None
|
||||
right_arm_max_relative_target: float | dict[str, float] | None = None
|
||||
right_arm_use_degrees: bool = False
|
||||
|
||||
# cameras (shared between both arms)
|
||||
|
||||
@@ -44,8 +44,8 @@ class HopeJrArmConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -110,6 +110,7 @@ class KochFollower(Robot):
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.bus.disable_torque()
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
@@ -120,7 +121,6 @@ class KochFollower(Robot):
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.bus.disable_torque()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
|
||||
@@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_reachy2 import Reachy2RobotConfig
|
||||
from .robot_reachy2 import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Robot,
|
||||
)
|
||||
@@ -0,0 +1,107 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.cameras.configs import ColorMode
|
||||
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("reachy2")
|
||||
@dataclass
|
||||
class Reachy2RobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors.
|
||||
max_relative_target: float | None = None
|
||||
|
||||
# IP address of the Reachy 2 robot
|
||||
ip_address: str | None = "localhost"
|
||||
|
||||
# If True, turn_off_smoothly() will be sent to the robot before disconnecting.
|
||||
disable_torque_on_disconnect: bool = False
|
||||
|
||||
# Tag for external commands control
|
||||
# Set to True if you use an external commands system to control the robot,
|
||||
# such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation
|
||||
# If True, robot.send_action() will not send commands to the robot.
|
||||
use_external_commands: bool = False
|
||||
|
||||
# Robot parts
|
||||
# Set to False to not add the corresponding joints part to the robot list of joints.
|
||||
# By default, all parts are set to True.
|
||||
with_mobile_base: bool = True
|
||||
with_l_arm: bool = True
|
||||
with_r_arm: bool = True
|
||||
with_neck: bool = True
|
||||
with_antennas: bool = True
|
||||
|
||||
# Robot cameras
|
||||
# Set to True if you want to use the corresponding cameras in the observations.
|
||||
# By default, only the teleop cameras are used.
|
||||
with_left_teleop_camera: bool = True
|
||||
with_right_teleop_camera: bool = True
|
||||
with_torso_camera: bool = False
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Add cameras with same ip_address as the robot
|
||||
if self.with_left_teleop_camera:
|
||||
self.cameras["teleop_left"] = Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="left",
|
||||
ip_address=self.ip_address,
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
if self.with_right_teleop_camera:
|
||||
self.cameras["teleop_right"] = Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
image_type="right",
|
||||
ip_address=self.ip_address,
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
if self.with_torso_camera:
|
||||
self.cameras["torso_rgb"] = Reachy2CameraConfig(
|
||||
name="depth",
|
||||
image_type="rgb",
|
||||
ip_address=self.ip_address,
|
||||
fps=15,
|
||||
width=640,
|
||||
height=480,
|
||||
color_mode=ColorMode.RGB,
|
||||
)
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
if not (
|
||||
self.with_mobile_base
|
||||
or self.with_l_arm
|
||||
or self.with_r_arm
|
||||
or self.with_neck
|
||||
or self.with_antennas
|
||||
):
|
||||
raise ValueError(
|
||||
"No Reachy2Robot part used.\n"
|
||||
"At least one part of the robot must be set to True "
|
||||
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
|
||||
)
|
||||
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from reachy2_sdk import ReachySDK
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .configuration_reachy2 import Reachy2RobotConfig
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_NECK_JOINTS = {
|
||||
"neck_yaw.pos": "head.neck.yaw",
|
||||
"neck_pitch.pos": "head.neck.pitch",
|
||||
"neck_roll.pos": "head.neck.roll",
|
||||
}
|
||||
|
||||
REACHY2_ANTENNAS_JOINTS = {
|
||||
"l_antenna.pos": "head.l_antenna",
|
||||
"r_antenna.pos": "head.r_antenna",
|
||||
}
|
||||
|
||||
REACHY2_R_ARM_JOINTS = {
|
||||
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
|
||||
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
|
||||
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
|
||||
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
|
||||
"r_wrist_roll.pos": "r_arm.wrist.roll",
|
||||
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
|
||||
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
|
||||
"r_gripper.pos": "r_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_L_ARM_JOINTS = {
|
||||
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
|
||||
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
|
||||
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
|
||||
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
|
||||
"l_wrist_roll.pos": "l_arm.wrist.roll",
|
||||
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
|
||||
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
|
||||
"l_gripper.pos": "l_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_VEL = {
|
||||
"mobile_base.vx": "vx",
|
||||
"mobile_base.vy": "vy",
|
||||
"mobile_base.vtheta": "vtheta",
|
||||
}
|
||||
|
||||
|
||||
class Reachy2Robot(Robot):
|
||||
"""
|
||||
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
|
||||
"""
|
||||
|
||||
config_class = Reachy2RobotConfig
|
||||
name = "reachy2"
|
||||
|
||||
def __init__(self, config: Reachy2RobotConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
self.robot_type = self.config.type
|
||||
self.use_external_commands = self.config.use_external_commands
|
||||
|
||||
self.reachy: None | ReachySDK = None
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.logs: dict[str, float] = {}
|
||||
|
||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict[str, Any]:
|
||||
return {**self.motors_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self.motors_features
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]:
|
||||
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
|
||||
|
||||
@property
|
||||
def motors_features(self) -> dict[str, type]:
|
||||
if self.config.with_mobile_base:
|
||||
return {
|
||||
**dict.fromkeys(
|
||||
self.joints_dict.keys(),
|
||||
float,
|
||||
),
|
||||
**dict.fromkeys(
|
||||
REACHY2_VEL.keys(),
|
||||
float,
|
||||
),
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.joints_dict.keys(), float)
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.reachy.is_connected() if self.reachy is not None else False
|
||||
|
||||
def connect(self, calibrate: bool = False) -> None:
|
||||
self.reachy = ReachySDK(self.config.ip_address)
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.configure()
|
||||
|
||||
def configure(self) -> None:
|
||||
if self.reachy is not None:
|
||||
self.reachy.turn_on()
|
||||
self.reachy.reset_default_limits()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
|
||||
def _generate_joints_dict(self) -> dict[str, str]:
|
||||
joints = {}
|
||||
if self.config.with_neck:
|
||||
joints.update(REACHY2_NECK_JOINTS)
|
||||
if self.config.with_l_arm:
|
||||
joints.update(REACHY2_L_ARM_JOINTS)
|
||||
if self.config.with_r_arm:
|
||||
joints.update(REACHY2_R_ARM_JOINTS)
|
||||
if self.config.with_antennas:
|
||||
joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||
return joints
|
||||
|
||||
def _get_state(self) -> dict[str, float]:
|
||||
if self.reachy is not None:
|
||||
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
|
||||
if not self.config.with_mobile_base:
|
||||
return pos_dict
|
||||
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
||||
return {**pos_dict, **vel_dict}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
obs_dict: dict[str, Any] = {}
|
||||
|
||||
# Read Reachy 2 state
|
||||
before_read_t = time.perf_counter()
|
||||
obs_dict.update(self._get_state())
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
if self.reachy is not None:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
before_write_t = time.perf_counter()
|
||||
|
||||
vel = {}
|
||||
goal_pos = {}
|
||||
for key, val in action.items():
|
||||
if key not in self.joints_dict:
|
||||
if key not in REACHY2_VEL:
|
||||
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
|
||||
else:
|
||||
vel[REACHY2_VEL[key]] = float(val)
|
||||
else:
|
||||
if not self.use_external_commands and self.config.max_relative_target is not None:
|
||||
goal_pos[key] = float(val)
|
||||
goal_present_pos = {
|
||||
key: (
|
||||
goal_pos[key],
|
||||
self.reachy.joints[self.joints_dict[key]].present_position,
|
||||
)
|
||||
}
|
||||
safe_goal_pos = ensure_safe_goal_position(
|
||||
goal_present_pos, float(self.config.max_relative_target)
|
||||
)
|
||||
val = safe_goal_pos[key]
|
||||
self.reachy.joints[self.joints_dict[key]].goal_position = float(val)
|
||||
|
||||
if self.config.with_mobile_base:
|
||||
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
|
||||
|
||||
# We don't send the goal positions if we control Reachy 2 externally
|
||||
if not self.use_external_commands:
|
||||
self.reachy.send_goal_positions()
|
||||
if self.config.with_mobile_base:
|
||||
self.reachy.mobile_base.send_speed_command()
|
||||
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
return action
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.reachy is not None:
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
if self.config.disable_torque_on_disconnect:
|
||||
self.reachy.turn_off_smoothly()
|
||||
self.reachy.disconnect()
|
||||
@@ -30,9 +30,9 @@ class SO100FollowerConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# !/usr/bin/env python
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -17,39 +17,48 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_STATE
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.pipeline import (
|
||||
ActionProcessor,
|
||||
ComplementaryDataProcessor,
|
||||
from lerobot.processor import (
|
||||
ActionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_reference_and_delta")
|
||||
@dataclass
|
||||
class EEReferenceAndDelta(ActionProcessor):
|
||||
class EEReferenceAndDelta(ActionProcessorStep):
|
||||
"""
|
||||
Compute the desired end-effector pose from the target pose and the current pose.
|
||||
Computes a target end-effector pose from a relative delta command.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
This step takes a desired change in position and orientation (`target_*`) and applies it to a
|
||||
reference end-effector pose to calculate an absolute target pose. The reference pose is derived
|
||||
from the current robot joint positions using forward kinematics.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
The processor can operate in two modes:
|
||||
1. `use_latched_reference=True`: The reference pose is "latched" or saved at the moment the action
|
||||
is first enabled. Subsequent commands are relative to this fixed reference.
|
||||
2. `use_latched_reference=False`: The reference pose is updated to the robot's current pose at
|
||||
every step.
|
||||
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model for forward kinematics.
|
||||
end_effector_step_sizes: A dictionary scaling the input delta commands.
|
||||
motor_names: A list of motor names required for forward kinematics.
|
||||
use_latched_reference: If True, latch the reference pose on enable; otherwise, always use the
|
||||
current pose as the reference.
|
||||
reference_ee_pose: Internal state storing the latched reference pose.
|
||||
_prev_enabled: Internal state to detect the rising edge of the enable signal.
|
||||
_command_when_disabled: Internal state to hold the last command while disabled.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
@@ -82,13 +91,13 @@ class EEReferenceAndDelta(ActionProcessor):
|
||||
# Current pose from FK on measured joints
|
||||
t_curr = self.kinematics.forward_kinematics(q)
|
||||
|
||||
enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
|
||||
tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
|
||||
ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
|
||||
tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
|
||||
wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
|
||||
wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
|
||||
wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))
|
||||
enabled = bool(new_action.pop("enabled", 0))
|
||||
tx = float(new_action.pop("target_x", 0.0))
|
||||
ty = float(new_action.pop("target_y", 0.0))
|
||||
tz = float(new_action.pop("target_z", 0.0))
|
||||
wx = float(new_action.pop("target_wx", 0.0))
|
||||
wy = float(new_action.pop("target_wy", 0.0))
|
||||
wz = float(new_action.pop("target_wz", 0.0))
|
||||
|
||||
desired = None
|
||||
|
||||
@@ -124,54 +133,57 @@ class EEReferenceAndDelta(ActionProcessor):
|
||||
# Write action fields
|
||||
pos = desired[:3, 3]
|
||||
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
||||
new_action[f"{ACTION}.ee.x"] = float(pos[0])
|
||||
new_action[f"{ACTION}.ee.y"] = float(pos[1])
|
||||
new_action[f"{ACTION}.ee.z"] = float(pos[2])
|
||||
new_action[f"{ACTION}.ee.wx"] = float(tw[0])
|
||||
new_action[f"{ACTION}.ee.wy"] = float(tw[1])
|
||||
new_action[f"{ACTION}.ee.wz"] = float(tw[2])
|
||||
new_action["ee.x"] = float(pos[0])
|
||||
new_action["ee.y"] = float(pos[1])
|
||||
new_action["ee.z"] = float(pos[2])
|
||||
new_action["ee.wx"] = float(tw[0])
|
||||
new_action["ee.wy"] = float(tw[1])
|
||||
new_action["ee.wz"] = float(tw[2])
|
||||
|
||||
self._prev_enabled = enabled
|
||||
return new_action
|
||||
|
||||
def reset(self):
|
||||
"""Resets the internal state of the processor."""
|
||||
self._prev_enabled = False
|
||||
self.reference_ee_pose = None
|
||||
self._command_when_disabled = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.enabled", None)
|
||||
features.pop(f"{ACTION}.target_x", None)
|
||||
features.pop(f"{ACTION}.target_y", None)
|
||||
features.pop(f"{ACTION}.target_z", None)
|
||||
features.pop(f"{ACTION}.target_wx", None)
|
||||
features.pop(f"{ACTION}.target_wy", None)
|
||||
features.pop(f"{ACTION}.target_wz", None)
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION].pop("enabled", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_x", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_y", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_z", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_wx", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_wy", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_wz", None)
|
||||
|
||||
features[f"{ACTION}.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[PipelineFeatureType.ACTION]["ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_bounds_and_safety")
|
||||
@dataclass
|
||||
class EEBoundsAndSafety(ActionProcessor):
|
||||
class EEBoundsAndSafety(ActionProcessorStep):
|
||||
"""
|
||||
Clip the end-effector pose to the bounds and check for jumps.
|
||||
Clips the end-effector pose to predefined bounds and checks for unsafe jumps.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
This step ensures that the target end-effector pose remains within a safe operational workspace.
|
||||
It also moderates the command to prevent large, sudden movements between consecutive steps.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
Attributes:
|
||||
end_effector_bounds: A dictionary with "min" and "max" keys for position clipping.
|
||||
max_ee_step_m: The maximum allowed change in position (in meters) between steps.
|
||||
max_ee_twist_step_rad: The maximum allowed change in orientation (in radians) between steps.
|
||||
_last_pos: Internal state storing the last commanded position.
|
||||
_last_twist: Internal state storing the last commanded orientation.
|
||||
"""
|
||||
|
||||
end_effector_bounds: dict
|
||||
@@ -181,15 +193,17 @@ class EEBoundsAndSafety(ActionProcessor):
|
||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
x = act.get("ee.x", None)
|
||||
y = act.get("ee.y", None)
|
||||
z = act.get("ee.z", None)
|
||||
wx = act.get("ee.wx", None)
|
||||
wy = act.get("ee.wy", None)
|
||||
wz = act.get("ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
return act
|
||||
raise ValueError(
|
||||
"Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action"
|
||||
)
|
||||
|
||||
pos = np.array([x, y, z], dtype=float)
|
||||
twist = np.array([wx, wy, wz], dtype=float)
|
||||
@@ -208,38 +222,40 @@ class EEBoundsAndSafety(ActionProcessor):
|
||||
self._last_pos = pos
|
||||
self._last_twist = twist
|
||||
|
||||
act[f"{ACTION}.ee.x"] = float(pos[0])
|
||||
act[f"{ACTION}.ee.y"] = float(pos[1])
|
||||
act[f"{ACTION}.ee.z"] = float(pos[2])
|
||||
act[f"{ACTION}.ee.wx"] = float(twist[0])
|
||||
act[f"{ACTION}.ee.wy"] = float(twist[1])
|
||||
act[f"{ACTION}.ee.wz"] = float(twist[2])
|
||||
act["ee.x"] = float(pos[0])
|
||||
act["ee.y"] = float(pos[1])
|
||||
act["ee.z"] = float(pos[2])
|
||||
act["ee.wx"] = float(twist[0])
|
||||
act["ee.wy"] = float(twist[1])
|
||||
act["ee.wz"] = float(twist[2])
|
||||
return act
|
||||
|
||||
def reset(self):
|
||||
"""Resets the last known position and orientation."""
|
||||
self._last_pos = None
|
||||
self._last_twist = None
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
|
||||
@dataclass
|
||||
class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
"""
|
||||
Compute the desired joint positions from the desired end-effector pose.
|
||||
Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
This step translates a Cartesian command (position and orientation of the end-effector) into
|
||||
the corresponding joint-space commands for each motor.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.joint_name_1.pos": float,
|
||||
"action.joint_name_2.pos": float,
|
||||
...
|
||||
"action.joint_name_n.pos": float,
|
||||
}
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model for inverse kinematics.
|
||||
motor_names: A list of motor names for which to compute joint positions.
|
||||
q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver.
|
||||
initial_guess_current_joints: If True, use the robot's current joint state as the IK guess.
|
||||
If False, use the solution from the previous step.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
@@ -248,18 +264,19 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
initial_guess_current_joints: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
act = transition.get(TransitionKey.ACTION) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
new_transition = transition.copy()
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
x = act.get("ee.x", None)
|
||||
y = act.get("ee.y", None)
|
||||
z = act.get("ee.z", None)
|
||||
wx = act.get("ee.wx", None)
|
||||
wy = act.get("ee.wy", None)
|
||||
wz = act.get("ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
return transition
|
||||
return new_transition
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
@@ -286,23 +303,31 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
new_act = dict(act)
|
||||
for i, name in enumerate(self.motor_names):
|
||||
if name == "gripper":
|
||||
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
|
||||
# TODO(pepijn): Investigate if this is correct
|
||||
# Do we want an observation key in the action field?
|
||||
new_act["gripper.pos"] = float(raw["gripper"])
|
||||
else:
|
||||
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
new_act[f"{name}.pos"] = float(q_target[i])
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
if not self.initial_guess_current_joints:
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
return transition
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION]["gripper.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
for name in self.motor_names:
|
||||
features[f"{ACTION}.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
"""Resets the initial guess for the IK solver."""
|
||||
self.q_curr = None
|
||||
|
||||
|
||||
@@ -310,17 +335,18 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
@dataclass
|
||||
class GripperVelocityToJoint(ProcessorStep):
|
||||
"""
|
||||
Convert the gripper velocity to a joint velocity.
|
||||
Converts a gripper velocity command into a target gripper joint position.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.gripper": float,
|
||||
}
|
||||
This step integrates a normalized velocity command over time to produce a position command,
|
||||
taking the current gripper position as a starting point. It also supports a discrete mode
|
||||
where integer actions map to open, close, or no-op.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.gripper.pos": float,
|
||||
}
|
||||
Attributes:
|
||||
motor_names: A list of motor names, which must include 'gripper'.
|
||||
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
|
||||
clip_min: The minimum allowed gripper joint position.
|
||||
clip_max: The maximum allowed gripper joint position.
|
||||
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
|
||||
"""
|
||||
|
||||
motor_names: list[str]
|
||||
@@ -330,67 +356,72 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
discrete_gripper: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION) or {}
|
||||
act = transition.get(TransitionKey.ACTION) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if f"{ACTION}.gripper" not in act:
|
||||
return transition
|
||||
if "gripper" not in act:
|
||||
raise ValueError("Required action key 'gripper' not found in transition")
|
||||
|
||||
if "gripper" not in self.motor_names:
|
||||
new_act = dict(act)
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
return transition
|
||||
raise ValueError(
|
||||
f"Required motor name 'gripper' not found in self.motor_names={self.motor_names}"
|
||||
)
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
# 0: open, 1: close, 2: stay
|
||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
gripper_action = act.get(f"{ACTION}.gripper", 1.0)
|
||||
gripper_action = act.get("gripper", 1.0)
|
||||
gripper_action = gripper_action - 1.0
|
||||
gripper_action *= self.clip_max
|
||||
act[f"{ACTION}.gripper"] = gripper_action
|
||||
act["gripper"] = gripper_action
|
||||
|
||||
# Get current gripper position from complementary data
|
||||
raw = comp.get("raw_joint_positions") or {}
|
||||
curr_pos = float(raw.get("gripper"))
|
||||
|
||||
# Compute desired gripper velocity
|
||||
u = float(act.get(f"{ACTION}.gripper", 0.0))
|
||||
u = float(act.get("gripper", 0.0))
|
||||
delta = u * float(self.speed_factor)
|
||||
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
|
||||
|
||||
new_act = dict(act)
|
||||
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
new_act["gripper.pos"] = gripper_pos
|
||||
new_act.pop("gripper", None)
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
|
||||
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
return transition
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION].pop("gripper", None)
|
||||
features[PipelineFeatureType.ACTION]["gripper.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.gripper.pos"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(1,)
|
||||
)
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
|
||||
@dataclass
|
||||
class ForwardKinematicsJointsToEE(ObservationProcessor):
|
||||
class ForwardKinematicsJointsToEE(ObservationProcessorStep):
|
||||
"""
|
||||
Compute the end-effector pose from the joint positions.
|
||||
Computes the end-effector pose from joint positions using forward kinematics (FK).
|
||||
|
||||
Input OBSERVATION keys:
|
||||
{
|
||||
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
|
||||
}
|
||||
This step is typically used to add the robot's Cartesian pose to the observation space,
|
||||
which can be useful for visualization or as an input to a policy.
|
||||
|
||||
Output OBSERVATION keys:
|
||||
{
|
||||
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model.
|
||||
motor_names: A list of motor names whose joint positions are used for FK.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
@@ -398,7 +429,7 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
|
||||
|
||||
def observation(self, obs: dict) -> dict:
|
||||
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
|
||||
return obs
|
||||
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
|
||||
|
||||
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
|
||||
t = self.kinematics.forward_kinematics(q)
|
||||
@@ -413,21 +444,29 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
|
||||
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
|
||||
return obs
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz"]:
|
||||
features[f"{OBS_STATE}.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.{k}"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(1,)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_robot_observation")
|
||||
@dataclass
|
||||
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
|
||||
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Read the robot's current observation and insert it into the transition as complementary data.
|
||||
Reads the robot's current observation and adds it to the transition's complementary data.
|
||||
|
||||
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
|
||||
{ "<motor_name>": <float position>, ... }
|
||||
This step acts as a bridge to the physical robot, injecting its real-time sensor readings
|
||||
(like raw joint positions) into the data processing pipeline. This data is then available
|
||||
for other processing steps.
|
||||
|
||||
Attributes:
|
||||
robot: An instance of a `Robot` class used to get observations from hardware.
|
||||
"""
|
||||
|
||||
robot: Robot
|
||||
@@ -444,3 +483,8 @@ class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessor):
|
||||
if isinstance(k, str) and k.endswith(".pos")
|
||||
}
|
||||
return new_comp
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
@@ -30,9 +30,9 @@ class SO101FollowerConfig(RobotConfig):
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -24,11 +24,6 @@ from ..config import RobotConfig
|
||||
@RobotConfig.register_subclass("stretch3")
|
||||
@dataclass
|
||||
class Stretch3RobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
|
||||
@@ -57,6 +57,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .bi_so100_follower import BiSO100Follower
|
||||
|
||||
return BiSO100Follower(config)
|
||||
elif config.type == "reachy2":
|
||||
from .reachy2 import Reachy2Robot
|
||||
|
||||
return Reachy2Robot(config)
|
||||
elif config.type == "mock_robot":
|
||||
from tests.mocks.mock_robot import MockRobot
|
||||
|
||||
@@ -67,7 +71,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
|
||||
# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset
|
||||
def ensure_safe_goal_position(
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float]
|
||||
) -> dict[str, float]:
|
||||
"""Caps relative action target magnitude for safety."""
|
||||
|
||||
|
||||
@@ -28,15 +28,15 @@ class ViperXConfig(RobotConfig):
|
||||
|
||||
# /!\ FOR SAFETY, READ THIS /!\
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor
|
||||
# names to the max_relative_target value for that motor.
|
||||
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
|
||||
# When you feel more confident with teleoperation or running the policy, you can extend
|
||||
# this safety limit and even removing it by setting it to `null`.
|
||||
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
|
||||
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
|
||||
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
|
||||
max_relative_target: int | None = 5
|
||||
max_relative_target: float | dict[str, float] = 5.0
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
+45
-14
@@ -56,6 +56,8 @@ from copy import deepcopy
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
@@ -69,9 +71,11 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.processor.core import TransitionKey
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
@@ -84,6 +88,10 @@ from lerobot.utils.utils import (
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
@@ -120,7 +128,6 @@ def rollout(
|
||||
The dictionary described above.
|
||||
"""
|
||||
assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
|
||||
device = get_device_from_parameters(policy)
|
||||
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
@@ -151,19 +158,18 @@ def rollout(
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
}
|
||||
|
||||
# Infer "task" from attributes of environments.
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
|
||||
observation = preprocessor(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action = action.to("cpu").numpy()
|
||||
action: np.ndarray = action.to("cpu").numpy()
|
||||
action: np.ndarray = action.to("cpu").numpy()
|
||||
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
|
||||
# Apply the next action.
|
||||
@@ -220,6 +226,10 @@ def rollout(
|
||||
def eval_policy(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
n_episodes: int,
|
||||
max_episodes_rendered: int = 0,
|
||||
videos_dir: Path | None = None,
|
||||
@@ -296,8 +306,14 @@ def eval_policy(
|
||||
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
||||
)
|
||||
rollout_data = rollout(
|
||||
env,
|
||||
policy,
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
@@ -479,13 +495,28 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
|
||||
|
||||
policy.eval()
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
|
||||
)
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy(
|
||||
env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
env=env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
start_seed=cfg.seed,
|
||||
|
||||
@@ -62,7 +62,7 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl.gym_manipulator import (
|
||||
create_transition,
|
||||
@@ -98,9 +98,7 @@ from lerobot.utils.utils import (
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
# Main entry point
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
@@ -207,9 +205,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
# Core algorithm functions
|
||||
|
||||
|
||||
def act_with_policy(
|
||||
@@ -406,9 +402,7 @@ def act_with_policy(
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions #
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
@@ -653,9 +647,7 @@ def interactions_stream(
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
#################################################
|
||||
# Policy functions #
|
||||
#################################################
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
@@ -687,9 +679,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities functions #
|
||||
#################################################
|
||||
# Utilities functions
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
|
||||
@@ -29,25 +29,27 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlRobotEnvConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
AddTeleopActionAsComplimentaryData,
|
||||
AddTeleopEventsAsInfo,
|
||||
DeviceProcessor,
|
||||
GripperPenaltyProcessor,
|
||||
ImageCropResizeProcessor,
|
||||
InterventionActionProcessor,
|
||||
JointVelocityProcessor,
|
||||
MapDeltaActionToRobotAction,
|
||||
MapTensorToDeltaActionDict,
|
||||
MotorCurrentProcessor,
|
||||
Numpy2TorchActionProcessor,
|
||||
RewardClassifierProcessor,
|
||||
RobotProcessor,
|
||||
TimeLimitProcessor,
|
||||
ToBatchProcessor,
|
||||
Torch2NumpyActionProcessor,
|
||||
VanillaObservationProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
GripperPenaltyProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
JointVelocityProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
MapTensorToDeltaActionDictStep,
|
||||
MotorCurrentProcessorStep,
|
||||
Numpy2TorchActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
TimeLimitProcessorStep,
|
||||
Torch2NumpyActionProcessorStep,
|
||||
TransitionKey,
|
||||
VanillaObservationProcessorStep,
|
||||
create_transition,
|
||||
)
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
make_robot_from_config,
|
||||
@@ -98,21 +100,6 @@ class GymManipulatorConfig:
|
||||
device: str = "cpu"
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
|
||||
) -> dict[str, Any]:
|
||||
"""Create an EnvTransition dictionary with default values."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info if info is not None else {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
|
||||
}
|
||||
|
||||
|
||||
def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None:
|
||||
"""Reset robot arm to target position using smooth trajectory."""
|
||||
current_position_dict = robot_arm.bus.sync_read("Present_Position")
|
||||
@@ -375,19 +362,21 @@ def make_processors(
|
||||
|
||||
if cfg.name == "gym_hil":
|
||||
action_pipeline_steps = [
|
||||
InterventionActionProcessor(terminate_on_success=terminate_on_success),
|
||||
Torch2NumpyActionProcessor(),
|
||||
InterventionActionProcessorStep(terminate_on_success=terminate_on_success),
|
||||
Torch2NumpyActionProcessorStep(),
|
||||
]
|
||||
|
||||
# Minimal processor pipeline for GymHIL simulation
|
||||
env_pipeline_steps = [
|
||||
Numpy2TorchActionProcessor(),
|
||||
VanillaObservationProcessor(),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device=device),
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=device),
|
||||
]
|
||||
|
||||
return RobotProcessor(steps=env_pipeline_steps), RobotProcessor(steps=action_pipeline_steps)
|
||||
return DataProcessorPipeline(steps=env_pipeline_steps), DataProcessorPipeline(
|
||||
steps=action_pipeline_steps
|
||||
)
|
||||
|
||||
# Full processor pipeline for real robot environment
|
||||
# Get robot and motor information for kinematics
|
||||
@@ -402,13 +391,13 @@ def make_processors(
|
||||
joint_names=motor_names,
|
||||
)
|
||||
|
||||
env_pipeline_steps = [VanillaObservationProcessor()]
|
||||
env_pipeline_steps = [VanillaObservationProcessorStep()]
|
||||
|
||||
if cfg.processor.observation is not None:
|
||||
if cfg.processor.observation.add_joint_velocity_to_observation:
|
||||
env_pipeline_steps.append(JointVelocityProcessor(dt=1.0 / cfg.fps))
|
||||
env_pipeline_steps.append(JointVelocityProcessorStep(dt=1.0 / cfg.fps))
|
||||
if cfg.processor.observation.add_current_to_observation:
|
||||
env_pipeline_steps.append(MotorCurrentProcessor(robot=env.robot))
|
||||
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
||||
|
||||
if kinematics_solver is not None:
|
||||
env_pipeline_steps.append(
|
||||
@@ -420,7 +409,7 @@ def make_processors(
|
||||
|
||||
if cfg.processor.image_preprocessing is not None:
|
||||
env_pipeline_steps.append(
|
||||
ImageCropResizeProcessor(
|
||||
ImageCropResizeProcessorStep(
|
||||
crop_params_dict=cfg.processor.image_preprocessing.crop_params_dict,
|
||||
resize_size=cfg.processor.image_preprocessing.resize_size,
|
||||
)
|
||||
@@ -429,13 +418,13 @@ def make_processors(
|
||||
# Add time limit processor if reset config exists
|
||||
if cfg.processor.reset is not None:
|
||||
env_pipeline_steps.append(
|
||||
TimeLimitProcessor(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
|
||||
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
|
||||
)
|
||||
|
||||
# Add gripper penalty processor if gripper config exists and enabled
|
||||
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
|
||||
env_pipeline_steps.append(
|
||||
GripperPenaltyProcessor(
|
||||
GripperPenaltyProcessorStep(
|
||||
penalty=cfg.processor.gripper.gripper_penalty,
|
||||
max_gripper_pos=cfg.processor.max_gripper_pos,
|
||||
)
|
||||
@@ -446,7 +435,7 @@ def make_processors(
|
||||
and cfg.processor.reward_classifier.pretrained_path is not None
|
||||
):
|
||||
env_pipeline_steps.append(
|
||||
RewardClassifierProcessor(
|
||||
RewardClassifierProcessorStep(
|
||||
pretrained_path=cfg.processor.reward_classifier.pretrained_path,
|
||||
device=device,
|
||||
success_threshold=cfg.processor.reward_classifier.success_threshold,
|
||||
@@ -455,14 +444,14 @@ def make_processors(
|
||||
)
|
||||
)
|
||||
|
||||
env_pipeline_steps.append(ToBatchProcessor())
|
||||
env_pipeline_steps.append(DeviceProcessor(device=device))
|
||||
env_pipeline_steps.append(AddBatchDimensionProcessorStep())
|
||||
env_pipeline_steps.append(DeviceProcessorStep(device=device))
|
||||
|
||||
action_pipeline_steps = [
|
||||
AddTeleopActionAsComplimentaryData(teleop_device=teleop_device),
|
||||
AddTeleopEventsAsInfo(teleop_device=teleop_device),
|
||||
AddTeleopActionAsComplimentaryDataStep(teleop_device=teleop_device),
|
||||
AddTeleopEventsAsInfoStep(teleop_device=teleop_device),
|
||||
AddRobotObservationAsComplimentaryData(robot=env.robot),
|
||||
InterventionActionProcessor(
|
||||
InterventionActionProcessorStep(
|
||||
use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False,
|
||||
terminate_on_success=terminate_on_success,
|
||||
),
|
||||
@@ -472,8 +461,10 @@ def make_processors(
|
||||
if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None:
|
||||
# Add EE bounds and safety processor
|
||||
inverse_kinematics_steps = [
|
||||
MapTensorToDeltaActionDict(),
|
||||
MapDeltaActionToRobotAction(),
|
||||
MapTensorToDeltaActionDictStep(
|
||||
use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False
|
||||
),
|
||||
MapDeltaActionToRobotActionStep(),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes=cfg.processor.inverse_kinematics.end_effector_step_sizes,
|
||||
@@ -497,15 +488,15 @@ def make_processors(
|
||||
]
|
||||
action_pipeline_steps.extend(inverse_kinematics_steps)
|
||||
|
||||
return RobotProcessor(steps=env_pipeline_steps), RobotProcessor(steps=action_pipeline_steps)
|
||||
return DataProcessorPipeline(steps=env_pipeline_steps), DataProcessorPipeline(steps=action_pipeline_steps)
|
||||
|
||||
|
||||
def step_env_and_process_transition(
|
||||
env: gym.Env,
|
||||
transition: EnvTransition,
|
||||
action: torch.Tensor,
|
||||
env_processor: RobotProcessor,
|
||||
action_processor: RobotProcessor,
|
||||
env_processor: DataProcessorPipeline,
|
||||
action_processor: DataProcessorPipeline,
|
||||
):
|
||||
"""
|
||||
Execute one step with processor pipeline.
|
||||
@@ -554,8 +545,8 @@ def step_env_and_process_transition(
|
||||
|
||||
def control_loop(
|
||||
env: gym.Env,
|
||||
env_processor: RobotProcessor,
|
||||
action_processor: RobotProcessor,
|
||||
env_processor: DataProcessorPipeline,
|
||||
action_processor: DataProcessorPipeline,
|
||||
teleop_device: Teleoperator,
|
||||
cfg: GymManipulatorConfig,
|
||||
) -> None:
|
||||
@@ -709,7 +700,9 @@ def control_loop(
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
def replay_trajectory(env: gym.Env, action_processor: RobotProcessor, cfg: GymManipulatorConfig) -> None:
|
||||
def replay_trajectory(
|
||||
env: gym.Env, action_processor: DataProcessorPipeline, cfg: GymManipulatorConfig
|
||||
) -> None:
|
||||
"""Replay recorded trajectory on robot environment."""
|
||||
assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay"
|
||||
|
||||
|
||||
@@ -103,11 +103,6 @@ from lerobot.utils.wandb_utils import WandBLogger
|
||||
LOG_PREFIX = "[LEARNER]"
|
||||
|
||||
|
||||
#################################################
|
||||
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
|
||||
#################################################
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train_cli(cfg: TrainRLServerPipelineConfig):
|
||||
if not use_threads(cfg):
|
||||
@@ -250,9 +245,7 @@ def start_learner_threads(
|
||||
logging.info("[LEARNER] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
# Core algorithm functions
|
||||
|
||||
|
||||
def add_actor_information_and_train(
|
||||
@@ -820,9 +813,7 @@ def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.M
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
#################################################
|
||||
# Training setup functions #
|
||||
#################################################
|
||||
# Training setup functions
|
||||
|
||||
|
||||
def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig:
|
||||
@@ -1023,9 +1014,7 @@ def initialize_offline_replay_buffer(
|
||||
return offline_replay_buffer
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities/Helpers functions #
|
||||
#################################################
|
||||
# Utilities/Helpers functions
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
|
||||
@@ -26,7 +26,7 @@ from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
@@ -65,6 +65,28 @@ def update_policy(
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
|
||||
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
|
||||
learning rate scheduler. It also handles mixed-precision training via a GradScaler.
|
||||
|
||||
Args:
|
||||
train_metrics: A MetricsTracker instance to record training statistics.
|
||||
policy: The policy model to be trained.
|
||||
batch: A batch of training data.
|
||||
optimizer: The optimizer used to update the policy's parameters.
|
||||
grad_clip_norm: The maximum norm for gradient clipping.
|
||||
grad_scaler: The GradScaler for automatic mixed-precision training.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
use_amp: A boolean indicating whether to use automatic mixed precision.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The updated MetricsTracker with new statistics for this step.
|
||||
- A dictionary of outputs from the policy's forward pass, for logging purposes.
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
@@ -108,6 +130,20 @@ def update_policy(
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
"""
|
||||
Main function to train a policy.
|
||||
|
||||
This function orchestrates the entire training pipeline, including:
|
||||
- Setting up logging, seeding, and device configuration.
|
||||
- Creating the dataset, evaluation environment (if applicable), policy, and optimizer.
|
||||
- Handling resumption from a checkpoint.
|
||||
- Running the main training loop, which involves fetching data batches and calling `update_policy`.
|
||||
- Periodically logging metrics, saving model checkpoints, and evaluating the policy.
|
||||
- Pushing the final trained model to the Hugging Face Hub if configured.
|
||||
|
||||
Args:
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
"""
|
||||
cfg.validate()
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
@@ -153,9 +189,11 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json")
|
||||
preprocessor.from_pretrained(
|
||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
postprocessor.from_pretrained(
|
||||
cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
@@ -260,9 +298,11 @@ def train(cfg: TrainPipelineConfig):
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
env=eval_env,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
|
||||
+51
-20
@@ -55,19 +55,20 @@ import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor import EnvTransition, IdentityProcessorStep, RobotProcessorPipeline, TransitionKey
|
||||
from lerobot.processor.converters import (
|
||||
to_output_robot_action,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
action_to_transition,
|
||||
identity_transition,
|
||||
observation_to_transition,
|
||||
transition_to_action,
|
||||
)
|
||||
from lerobot.processor.pipeline import IdentityProcessor
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -105,9 +106,9 @@ class TeleoperateConfig:
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Optional processors for data transformation
|
||||
teleop_action_processor: RobotProcessor | None = None # runs after teleop
|
||||
robot_action_processor: RobotProcessor | None = None # runs before robot
|
||||
robot_observation_processor: RobotProcessor | None = None # runs after robot
|
||||
teleop_action_processor: RobotProcessorPipeline | None = None # runs after teleop
|
||||
robot_action_processor: RobotProcessorPipeline | None = None # runs before robot
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None # runs after robot
|
||||
|
||||
|
||||
def teleop_loop(
|
||||
@@ -116,21 +117,47 @@ def teleop_loop(
|
||||
fps: int,
|
||||
display_data: bool = False,
|
||||
duration: float | None = None,
|
||||
teleop_action_processor: RobotProcessor | None = None,
|
||||
robot_action_processor: RobotProcessor | None = None,
|
||||
robot_observation_processor: RobotProcessor | None = None,
|
||||
teleop_action_processor: RobotProcessorPipeline[EnvTransition] | None = None,
|
||||
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] | None = None,
|
||||
robot_observation_processor: RobotProcessorPipeline[EnvTransition] | None = None,
|
||||
):
|
||||
"""
|
||||
This function continuously reads actions from a teleoperation device, processes them through optional
|
||||
pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a
|
||||
specified frequency until a set duration is reached or it is manually interrupted.
|
||||
|
||||
Args:
|
||||
teleop: The teleoperator device instance providing control actions.
|
||||
robot: The robot instance being controlled.
|
||||
fps: The target frequency for the control loop in frames per second.
|
||||
display_data: If True, fetches robot observations and displays them in the console and Rerun.
|
||||
duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely.
|
||||
teleop_action_processor: An optional pipeline to process raw actions from the teleoperator.
|
||||
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
|
||||
robot_observation_processor: An optional pipeline to process raw observations from the robot.
|
||||
"""
|
||||
# Initialize processors with defaults if not provided
|
||||
teleop_action_processor = teleop_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=to_transition_teleop_action, to_output=lambda tr: tr
|
||||
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
|
||||
teleop_action_processor
|
||||
or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=identity_transition
|
||||
)
|
||||
)
|
||||
robot_action_processor = robot_action_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=to_output_robot_action, # type: ignore[arg-type]
|
||||
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] = (
|
||||
robot_action_processor
|
||||
or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=identity_transition,
|
||||
to_output=transition_to_action, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
robot_observation_processor = robot_observation_processor or RobotProcessor(
|
||||
steps=[IdentityProcessor()], to_transition=to_transition_robot_observation, to_output=lambda tr: tr
|
||||
robot_observation_processor: RobotProcessorPipeline[EnvTransition] = (
|
||||
robot_observation_processor
|
||||
or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
)
|
||||
|
||||
# Reset processors
|
||||
@@ -161,7 +188,11 @@ def teleop_loop(
|
||||
obs = robot.get_observation()
|
||||
# Process robot observation through pipeline
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
log_rerun_data([obs_transition, teleop_transition])
|
||||
|
||||
log_rerun_data(
|
||||
observation=obs_transition.get(TransitionKey.OBSERVATION),
|
||||
action=teleop_transition.get(TransitionKey.ACTION),
|
||||
)
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
|
||||
|
||||
@@ -88,6 +88,7 @@ class KochLeader(Teleoperator):
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.bus.disable_torque()
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
@@ -98,7 +99,6 @@ class KochLeader(Teleoperator):
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.bus.disable_torque()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
|
||||
@@ -16,45 +16,53 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.processor import ActionProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneOS
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_phone_action_to_robot_action")
|
||||
@dataclass
|
||||
class MapPhoneActionToRobotAction(ActionProcessor):
|
||||
class MapPhoneActionToRobotAction(ActionProcessorStep):
|
||||
"""
|
||||
Map calibrated phone pose (actions) to the inputs for robot actions
|
||||
Maps calibrated phone pose actions to standardized robot action inputs.
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.phone.enabled": bool,
|
||||
"action.phone.pos": np.ndarray,
|
||||
"action.phone.rot": Rotation,
|
||||
"action.phone.raw_inputs": dict,
|
||||
}
|
||||
This processor step acts as a bridge between the phone teleoperator's output
|
||||
and the robot's expected action format. It remaps the phone's 6-DoF pose
|
||||
(position and rotation) to the robot's target end-effector pose, applying
|
||||
necessary axis inversions and swaps. It also interprets platform-specific
|
||||
button presses to generate a gripper command.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"action.gripper": float,
|
||||
}
|
||||
Attributes:
|
||||
platform: The operating system of the phone (iOS or Android), used
|
||||
to determine the correct button mappings for the gripper.
|
||||
"""
|
||||
|
||||
platform: PhoneOS
|
||||
_enabled_prev: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
"""
|
||||
Processes the phone action dictionary to create a robot action dictionary.
|
||||
|
||||
Args:
|
||||
act: The input action dictionary from the phone teleoperator.
|
||||
|
||||
Returns:
|
||||
A new action dictionary formatted for the robot controller.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'pos' or 'rot' keys are missing from the input action.
|
||||
"""
|
||||
# Pop them from the action
|
||||
enabled = bool(act.pop("action.phone.enabled", 0))
|
||||
pos = act.pop("action.phone.pos", None)
|
||||
rot = act.pop("action.phone.rot", None)
|
||||
inputs = act.pop("action.phone.raw_inputs", {})
|
||||
enabled = bool(act.pop(f"{ACTION}.phone.enabled", 0))
|
||||
pos = act.pop(f"{ACTION}.phone.pos", None)
|
||||
rot = act.pop(f"{ACTION}.phone.rot", None)
|
||||
inputs = act.pop(f"{ACTION}.phone.raw_inputs", {})
|
||||
|
||||
if pos is None or rot is None:
|
||||
return act
|
||||
raise ValueError("pos and rot must be present in action")
|
||||
|
||||
rotvec = rot.as_rotvec() # Absolute orientation as rotvec
|
||||
|
||||
@@ -69,28 +77,30 @@ class MapPhoneActionToRobotAction(ActionProcessor):
|
||||
) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
|
||||
|
||||
# For some actions we need to invert the axis
|
||||
act["action.enabled"] = enabled
|
||||
act["action.target_x"] = -pos[1] if enabled else 0.0
|
||||
act["action.target_y"] = pos[0] if enabled else 0.0
|
||||
act["action.target_z"] = pos[2] if enabled else 0.0
|
||||
act["action.target_wx"] = rotvec[1] if enabled else 0.0
|
||||
act["action.target_wy"] = rotvec[0] if enabled else 0.0
|
||||
act["action.target_wz"] = -rotvec[2] if enabled else 0.0
|
||||
act["action.gripper"] = gripper # Still send gripper action when disabled
|
||||
act[f"{ACTION}.enabled"] = enabled
|
||||
act[f"{ACTION}.target_x"] = -pos[1] if enabled else 0.0
|
||||
act[f"{ACTION}.target_y"] = pos[0] if enabled else 0.0
|
||||
act[f"{ACTION}.target_z"] = pos[2] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wx"] = rotvec[1] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wy"] = rotvec[0] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wz"] = -rotvec[2] if enabled else 0.0
|
||||
act[f"{ACTION}.gripper"] = gripper # Still send gripper action when disabled
|
||||
return act
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop("action.phone.enabled", None)
|
||||
features.pop("action.phone.pos", None)
|
||||
features.pop("action.phone.rot", None)
|
||||
features.pop("action.phone.raw_inputs", None)
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION].pop("phone.enabled", None)
|
||||
features[PipelineFeatureType.ACTION].pop("phone.pos", None)
|
||||
features[PipelineFeatureType.ACTION].pop("phone.rot", None)
|
||||
features[PipelineFeatureType.ACTION].pop("phone.raw_inputs", None)
|
||||
|
||||
features["action.enabled"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.gripper"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[PipelineFeatureType.ACTION]["enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
@@ -24,12 +24,12 @@ import time
|
||||
|
||||
import hebi
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
from teleop import Teleop
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -101,32 +101,56 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
|
||||
)
|
||||
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
|
||||
|
||||
pos, rot = self._wait_for_capture_trigger()
|
||||
self._calib_pos = pos.copy()
|
||||
self._calib_rot_inv = rot.inv()
|
||||
position, rotation = self._wait_for_capture_trigger()
|
||||
self._calib_pos = position.copy()
|
||||
self._calib_rot_inv = rotation.inv()
|
||||
self._enabled = False
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
"""
|
||||
Blocks execution until the calibration trigger is detected from the iOS device.
|
||||
|
||||
This method enters a loop, continuously reading the phone's state. It waits for the user to press
|
||||
and hold the 'B1' button in the HEBI Mobile I/O app. Once B1 is pressed, the loop breaks and
|
||||
returns the phone's pose at that exact moment.
|
||||
|
||||
Returns:
|
||||
A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
|
||||
moment the trigger was activated.
|
||||
"""
|
||||
while True:
|
||||
ok, pos, rot, pose = self._read_current_pose()
|
||||
if not ok:
|
||||
has_pose, position, rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
io = getattr(pose, "io", None)
|
||||
b = getattr(io, "b", None) if io is not None else None
|
||||
b1 = False
|
||||
if b is not None:
|
||||
b1 = bool(b.get_int(1))
|
||||
if b1:
|
||||
return pos, rot
|
||||
io = getattr(fb_pose, "io", None)
|
||||
button_b = getattr(io, "b", None) if io is not None else None
|
||||
button_b1_pressed = False
|
||||
if button_b is not None:
|
||||
button_b1_pressed = bool(button_b.get_int(1))
|
||||
if button_b1_pressed:
|
||||
return position, rotation
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
"""
|
||||
Reads the instantaneous 6-DoF pose from the connected iOS device via the HEBI SDK.
|
||||
|
||||
This method fetches the latest feedback packet from the HEBI group, extracts the ARKit
|
||||
position and orientation, and converts them into a standard format. It also applies a
|
||||
configured camera offset to adjust the pose from the camera's frame to the phone's
|
||||
physical frame.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A boolean indicating if a valid pose was successfully read.
|
||||
- The 3D position as a NumPy array, or None if not available.
|
||||
- The orientation as a `Rotation` object, or None if not available.
|
||||
- The raw HEBI feedback object for accessing other data like button presses.
|
||||
"""
|
||||
fbk = self._group.get_next_feedback()
|
||||
pose = fbk[0]
|
||||
ar_pos = getattr(pose, "ar_position", None)
|
||||
@@ -141,13 +165,13 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
return True, pos, rot, pose
|
||||
|
||||
def get_action(self) -> dict:
|
||||
ok, raw_pos, raw_rot, pose = self._read_current_pose()
|
||||
if not ok or not self.is_calibrated:
|
||||
has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose or not self.is_calibrated:
|
||||
return {}
|
||||
|
||||
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
|
||||
raw_inputs: dict[str, float | int | bool] = {}
|
||||
io = getattr(pose, "io", None)
|
||||
io = getattr(fb_pose, "io", None)
|
||||
if io is not None:
|
||||
bank_a, bank_b = io.a, io.b
|
||||
if bank_a:
|
||||
@@ -165,11 +189,11 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
|
||||
# Rising edge then re-capture calibration immediately from current raw pose
|
||||
if enable and not self._enabled:
|
||||
self._reapply_position_calibration(raw_pos)
|
||||
self._reapply_position_calibration(raw_position)
|
||||
|
||||
# Apply calibration
|
||||
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
|
||||
rot_cal = self._calib_rot_inv * raw_rot
|
||||
pos_cal = self._calib_rot_inv.apply(raw_position - self._calib_pos)
|
||||
rot_cal = self._calib_rot_inv * raw_rotation
|
||||
|
||||
self._enabled = enable
|
||||
|
||||
@@ -229,7 +253,18 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
"""
|
||||
Blocks execution until the calibration trigger is detected from the Android device.
|
||||
|
||||
This method enters a loop, continuously checking the latest message received from the WebXR
|
||||
session. It waits for the user to touch and move their finger on the screen, which generates
|
||||
a `move` event. Once this event is detected, the loop breaks and returns the phone's current
|
||||
pose.
|
||||
|
||||
Returns:
|
||||
A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
|
||||
moment the trigger was activated.
|
||||
"""
|
||||
while True:
|
||||
with self._android_lock:
|
||||
msg = self._latest_message or {}
|
||||
@@ -242,6 +277,20 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
"""
|
||||
Reads the latest 6-DoF pose received from the Android device's WebXR session.
|
||||
|
||||
This method accesses the most recent pose data stored by the `_android_callback`. It uses a
|
||||
thread lock to safely read the shared `_latest_pose` variable. The pose, a 4x4 matrix, is
|
||||
then decomposed into position and rotation, and the configured camera offset is applied.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A boolean indicating if a valid pose was available.
|
||||
- The 3D position as a NumPy array, or None if no pose has been received yet.
|
||||
- The orientation as a `Rotation` object, or None if no pose has been received.
|
||||
- The raw 4x4 pose matrix as received from the teleop stream.
|
||||
"""
|
||||
with self._android_lock:
|
||||
if self._latest_pose is None:
|
||||
return False, None, None, None
|
||||
@@ -252,6 +301,19 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
return True, pos, rot, pose
|
||||
|
||||
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
|
||||
"""
|
||||
Callback function to handle incoming data from the Android teleop stream.
|
||||
|
||||
This method is executed by the `teleop` package's subscriber thread whenever a new
|
||||
pose and message are received from the WebXR session on the Android phone. It updates
|
||||
the internal state (`_latest_pose` and `_latest_message`) with the new data.
|
||||
A thread lock is used to ensure that these shared variables are updated atomically,
|
||||
preventing race conditions with the main thread that reads them.
|
||||
|
||||
Args:
|
||||
pose: A 4x4 NumPy array representing the phone's transformation matrix.
|
||||
message: A dictionary containing additional data, such as button presses or touch events.
|
||||
"""
|
||||
with self._android_lock:
|
||||
self._latest_pose = pose
|
||||
self._latest_message = message
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
|
||||
from .reachy2_teleoperator import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Teleoperator,
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("reachy2_teleoperator")
|
||||
@dataclass
|
||||
class Reachy2TeleoperatorConfig(TeleoperatorConfig):
|
||||
# IP address of the Reachy 2 robot used as teleoperator
|
||||
ip_address: str | None = "localhost"
|
||||
|
||||
# Whether to use the present position of the joints as actions
|
||||
# if False, the goal position of the joints will be used
|
||||
use_present_position: bool = False
|
||||
|
||||
# Which parts of the robot to use
|
||||
with_mobile_base: bool = True
|
||||
with_l_arm: bool = True
|
||||
with_r_arm: bool = True
|
||||
with_neck: bool = True
|
||||
with_antennas: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if not (
|
||||
self.with_mobile_base
|
||||
or self.with_l_arm
|
||||
or self.with_r_arm
|
||||
or self.with_neck
|
||||
or self.with_antennas
|
||||
):
|
||||
raise ValueError(
|
||||
"No Reachy2Teleoperator part used.\n"
|
||||
"At least one part of the robot must be set to True "
|
||||
"(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)"
|
||||
)
|
||||
@@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from reachy2_sdk import ReachySDK
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_NECK_JOINTS = {
|
||||
"neck_yaw.pos": "head.neck.yaw",
|
||||
"neck_pitch.pos": "head.neck.pitch",
|
||||
"neck_roll.pos": "head.neck.roll",
|
||||
}
|
||||
|
||||
REACHY2_ANTENNAS_JOINTS = {
|
||||
"l_antenna.pos": "head.l_antenna",
|
||||
"r_antenna.pos": "head.r_antenna",
|
||||
}
|
||||
|
||||
REACHY2_R_ARM_JOINTS = {
|
||||
"r_shoulder_pitch.pos": "r_arm.shoulder.pitch",
|
||||
"r_shoulder_roll.pos": "r_arm.shoulder.roll",
|
||||
"r_elbow_yaw.pos": "r_arm.elbow.yaw",
|
||||
"r_elbow_pitch.pos": "r_arm.elbow.pitch",
|
||||
"r_wrist_roll.pos": "r_arm.wrist.roll",
|
||||
"r_wrist_pitch.pos": "r_arm.wrist.pitch",
|
||||
"r_wrist_yaw.pos": "r_arm.wrist.yaw",
|
||||
"r_gripper.pos": "r_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_L_ARM_JOINTS = {
|
||||
"l_shoulder_pitch.pos": "l_arm.shoulder.pitch",
|
||||
"l_shoulder_roll.pos": "l_arm.shoulder.roll",
|
||||
"l_elbow_yaw.pos": "l_arm.elbow.yaw",
|
||||
"l_elbow_pitch.pos": "l_arm.elbow.pitch",
|
||||
"l_wrist_roll.pos": "l_arm.wrist.roll",
|
||||
"l_wrist_pitch.pos": "l_arm.wrist.pitch",
|
||||
"l_wrist_yaw.pos": "l_arm.wrist.yaw",
|
||||
"l_gripper.pos": "l_arm.gripper",
|
||||
}
|
||||
|
||||
REACHY2_VEL = {
|
||||
"mobile_base.vx": "vx",
|
||||
"mobile_base.vy": "vy",
|
||||
"mobile_base.vtheta": "vtheta",
|
||||
}
|
||||
|
||||
|
||||
class Reachy2Teleoperator(Teleoperator):
|
||||
"""
|
||||
[Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics.
|
||||
"""
|
||||
|
||||
config_class = Reachy2TeleoperatorConfig
|
||||
name = "reachy2_specific"
|
||||
|
||||
def __init__(self, config: Reachy2TeleoperatorConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.reachy: None | ReachySDK = None
|
||||
|
||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||
|
||||
def _generate_joints_dict(self) -> dict[str, str]:
|
||||
joints = {}
|
||||
if self.config.with_neck:
|
||||
joints.update(REACHY2_NECK_JOINTS)
|
||||
if self.config.with_l_arm:
|
||||
joints.update(REACHY2_L_ARM_JOINTS)
|
||||
if self.config.with_r_arm:
|
||||
joints.update(REACHY2_R_ARM_JOINTS)
|
||||
if self.config.with_antennas:
|
||||
joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||
return joints
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
if self.config.with_mobile_base:
|
||||
return {
|
||||
**dict.fromkeys(
|
||||
self.joints_dict.keys(),
|
||||
float,
|
||||
),
|
||||
**dict.fromkeys(
|
||||
REACHY2_VEL.keys(),
|
||||
float,
|
||||
),
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.joints_dict.keys(), float)
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.reachy.is_connected() if self.reachy is not None else False
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.reachy = ReachySDK(self.config.ip_address)
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
|
||||
if self.reachy and self.is_connected:
|
||||
if self.config.use_present_position:
|
||||
joint_action = {
|
||||
k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()
|
||||
}
|
||||
else:
|
||||
joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
|
||||
|
||||
if not self.config.with_mobile_base:
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return joint_action
|
||||
|
||||
if self.config.use_present_position:
|
||||
vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
||||
else:
|
||||
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return {**joint_action, **vel_action}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if self.reachy and self.is_connected:
|
||||
self.reachy.disconnect()
|
||||
@@ -77,5 +77,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
from .bi_so100_leader import BiSO100Leader
|
||||
|
||||
return BiSO100Leader(config)
|
||||
elif config.type == "reachy2_teleoperator":
|
||||
from .reachy2_teleoperator import Reachy2Teleoperator
|
||||
|
||||
return Reachy2Teleoperator(config)
|
||||
else:
|
||||
raise ValueError(config.type)
|
||||
|
||||
@@ -31,11 +31,25 @@ from termcolor import colored
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor, TransitionKey
|
||||
from lerobot.processor import PolicyProcessorPipeline, TransitionKey
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
"""
|
||||
Logs performance metrics for a single step of the robot control loop.
|
||||
|
||||
This function formats and prints a single line of log information, including episode/frame counters,
|
||||
total loop time (dt), and detailed timings for various robot and camera operations. It can also
|
||||
highlight performance drops in yellow if the actual FPS is lower than the target FPS.
|
||||
|
||||
Args:
|
||||
robot: The `Robot` instance, used to access its internal logs for detailed timings.
|
||||
dt_s: The total duration of the control loop step in seconds.
|
||||
episode_index: The index of the current episode.
|
||||
frame_index: The index of the current frame within the episode.
|
||||
fps: The target frames per second, used to check for performance degradation.
|
||||
"""
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items.append(f"ep:{episode_index}")
|
||||
@@ -81,7 +95,16 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
|
||||
|
||||
@cache
|
||||
def is_headless():
|
||||
"""Detects if python is running without a monitor."""
|
||||
"""
|
||||
Detects if the Python script is running in a headless environment (e.g., without a display).
|
||||
|
||||
This function attempts to import `pynput`, a library that requires a graphical environment.
|
||||
If the import fails, it assumes the environment is headless. The result is cached to avoid
|
||||
re-running the check.
|
||||
|
||||
Returns:
|
||||
True if the environment is determined to be headless, False otherwise.
|
||||
"""
|
||||
try:
|
||||
import pynput # noqa
|
||||
|
||||
@@ -102,12 +125,35 @@ def predict_action(
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
preprocessor: RobotProcessor,
|
||||
postprocessor: RobotProcessor,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
use_amp: bool,
|
||||
task: str | None = None,
|
||||
robot_type: str | None = None,
|
||||
):
|
||||
"""
|
||||
Performs a single-step inference to predict a robot action from an observation.
|
||||
|
||||
This function encapsulates the full inference pipeline:
|
||||
1. Prepares the observation by converting it to PyTorch tensors and adding a batch dimension.
|
||||
2. Runs the preprocessor pipeline on the observation.
|
||||
3. Feeds the processed observation to the policy to get a raw action.
|
||||
4. Runs the postprocessor pipeline on the raw action.
|
||||
5. Formats the final action by removing the batch dimension and moving it to the CPU.
|
||||
|
||||
Args:
|
||||
observation: A dictionary of NumPy arrays representing the robot's current observation.
|
||||
policy: The `PreTrainedPolicy` model to use for action prediction.
|
||||
device: The `torch.device` (e.g., 'cuda' or 'cpu') to run inference on.
|
||||
preprocessor: The `PolicyProcessorPipeline` for preprocessing observations.
|
||||
postprocessor: The `PolicyProcessorPipeline` for postprocessing actions.
|
||||
use_amp: A boolean to enable/disable Automatic Mixed Precision for CUDA inference.
|
||||
task: An optional string identifier for the task.
|
||||
robot_type: An optional string identifier for the robot type.
|
||||
|
||||
Returns:
|
||||
A `torch.Tensor` containing the predicted action, ready for the robot.
|
||||
"""
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
@@ -143,6 +189,18 @@ def predict_action(
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""
|
||||
Initializes a non-blocking keyboard listener for real-time user interaction.
|
||||
|
||||
This function sets up a listener for specific keys (right arrow, left arrow, escape) to control
|
||||
the program flow during execution, such as stopping recording or exiting loops. It gracefully
|
||||
handles headless environments where keyboard listening is not possible.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The `pynput.keyboard.Listener` instance, or `None` if in a headless environment.
|
||||
- A dictionary of event flags (e.g., `exit_early`) that are set by key presses.
|
||||
"""
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
@@ -184,6 +242,19 @@ def init_keyboard_listener():
|
||||
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
"""
|
||||
Validates the dataset repository name against the presence of a policy configuration.
|
||||
|
||||
This function enforces a naming convention: a dataset repository ID should start with "eval_"
|
||||
if and only if a policy configuration is provided for evaluation purposes.
|
||||
|
||||
Args:
|
||||
repo_id: The Hugging Face Hub repository ID of the dataset.
|
||||
policy_cfg: The configuration object for the policy, or `None`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the naming convention is violated.
|
||||
"""
|
||||
_, dataset_name = repo_id.split("/")
|
||||
# either repo_id doesnt start with "eval_" and there is no policy
|
||||
# or repo_id starts with "eval_" and there is a policy
|
||||
@@ -204,6 +275,21 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, features: dict
|
||||
) -> None:
|
||||
"""
|
||||
Checks if a dataset's metadata is compatible with the current robot and recording setup.
|
||||
|
||||
This function compares key metadata fields (`robot_type`, `fps`, and `features`) from the
|
||||
dataset against the current configuration to ensure that appended data will be consistent.
|
||||
|
||||
Args:
|
||||
dataset: The `LeRobotDataset` instance to check.
|
||||
robot: The `Robot` instance representing the current hardware setup.
|
||||
fps: The current recording frequency (frames per second).
|
||||
features: The dictionary of features for the current recording session.
|
||||
|
||||
Raises:
|
||||
ValueError: If any of the checked metadata fields do not match.
|
||||
"""
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Custom rotation utilities to replace scipy.spatial.transform.Rotation."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Rotation:
|
||||
"""
|
||||
Custom rotation class that provides a subset of scipy.spatial.transform.Rotation functionality.
|
||||
|
||||
Supports conversions between rotation vectors, rotation matrices, and quaternions.
|
||||
"""
|
||||
|
||||
def __init__(self, quat: np.ndarray) -> None:
|
||||
"""Initialize rotation from quaternion [x, y, z, w]."""
|
||||
self._quat = np.asarray(quat, dtype=float)
|
||||
# Normalize quaternion
|
||||
norm = np.linalg.norm(self._quat)
|
||||
if norm > 0:
|
||||
self._quat = self._quat / norm
|
||||
|
||||
@classmethod
|
||||
def from_rotvec(cls, rotvec: np.ndarray) -> "Rotation":
|
||||
"""
|
||||
Create rotation from rotation vector using Rodrigues' formula.
|
||||
|
||||
Args:
|
||||
rotvec: Rotation vector [x, y, z] where magnitude is angle in radians
|
||||
|
||||
Returns:
|
||||
Rotation instance
|
||||
"""
|
||||
rotvec = np.asarray(rotvec, dtype=float)
|
||||
angle = np.linalg.norm(rotvec)
|
||||
|
||||
if angle < 1e-8:
|
||||
# For very small angles, use identity quaternion
|
||||
quat = np.array([0.0, 0.0, 0.0, 1.0])
|
||||
else:
|
||||
axis = rotvec / angle
|
||||
half_angle = angle / 2.0
|
||||
sin_half = np.sin(half_angle)
|
||||
cos_half = np.cos(half_angle)
|
||||
|
||||
# Quaternion [x, y, z, w]
|
||||
quat = np.array([axis[0] * sin_half, axis[1] * sin_half, axis[2] * sin_half, cos_half])
|
||||
|
||||
return cls(quat)
|
||||
|
||||
@classmethod
|
||||
def from_matrix(cls, matrix: np.ndarray) -> "Rotation":
|
||||
"""
|
||||
Create rotation from 3x3 rotation matrix.
|
||||
|
||||
Args:
|
||||
matrix: 3x3 rotation matrix
|
||||
|
||||
Returns:
|
||||
Rotation instance
|
||||
"""
|
||||
matrix = np.asarray(matrix, dtype=float)
|
||||
|
||||
# Shepherd's method for converting rotation matrix to quaternion
|
||||
trace = np.trace(matrix)
|
||||
|
||||
if trace > 0:
|
||||
s = np.sqrt(trace + 1.0) * 2 # s = 4 * qw
|
||||
qw = 0.25 * s
|
||||
qx = (matrix[2, 1] - matrix[1, 2]) / s
|
||||
qy = (matrix[0, 2] - matrix[2, 0]) / s
|
||||
qz = (matrix[1, 0] - matrix[0, 1]) / s
|
||||
elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]:
|
||||
s = np.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 # s = 4 * qx
|
||||
qw = (matrix[2, 1] - matrix[1, 2]) / s
|
||||
qx = 0.25 * s
|
||||
qy = (matrix[0, 1] + matrix[1, 0]) / s
|
||||
qz = (matrix[0, 2] + matrix[2, 0]) / s
|
||||
elif matrix[1, 1] > matrix[2, 2]:
|
||||
s = np.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 # s = 4 * qy
|
||||
qw = (matrix[0, 2] - matrix[2, 0]) / s
|
||||
qx = (matrix[0, 1] + matrix[1, 0]) / s
|
||||
qy = 0.25 * s
|
||||
qz = (matrix[1, 2] + matrix[2, 1]) / s
|
||||
else:
|
||||
s = np.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 # s = 4 * qz
|
||||
qw = (matrix[1, 0] - matrix[0, 1]) / s
|
||||
qx = (matrix[0, 2] + matrix[2, 0]) / s
|
||||
qy = (matrix[1, 2] + matrix[2, 1]) / s
|
||||
qz = 0.25 * s
|
||||
|
||||
quat = np.array([qx, qy, qz, qw])
|
||||
return cls(quat)
|
||||
|
||||
@classmethod
|
||||
def from_quat(cls, quat: np.ndarray) -> "Rotation":
|
||||
"""
|
||||
Create rotation from quaternion.
|
||||
|
||||
Args:
|
||||
quat: Quaternion [x, y, z, w] or [w, x, y, z] (specify convention in docstring)
|
||||
This implementation expects [x, y, z, w] format
|
||||
|
||||
Returns:
|
||||
Rotation instance
|
||||
"""
|
||||
return cls(quat)
|
||||
|
||||
def as_matrix(self) -> np.ndarray:
|
||||
"""
|
||||
Convert rotation to 3x3 rotation matrix.
|
||||
|
||||
Returns:
|
||||
3x3 rotation matrix
|
||||
"""
|
||||
qx, qy, qz, qw = self._quat
|
||||
|
||||
# Compute rotation matrix from quaternion
|
||||
return np.array(
|
||||
[
|
||||
[1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)],
|
||||
[2 * (qx * qy + qz * qw), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qx * qw)],
|
||||
[2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx * qx + qy * qy)],
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
|
||||
def as_rotvec(self) -> np.ndarray:
|
||||
"""
|
||||
Convert rotation to rotation vector.
|
||||
|
||||
Returns:
|
||||
Rotation vector [x, y, z] where magnitude is angle in radians
|
||||
"""
|
||||
qx, qy, qz, qw = self._quat
|
||||
|
||||
# Ensure qw is positive for unique representation
|
||||
if qw < 0:
|
||||
qx, qy, qz, qw = -qx, -qy, -qz, -qw
|
||||
|
||||
# Compute angle and axis
|
||||
angle = 2.0 * np.arccos(np.clip(abs(qw), 0.0, 1.0))
|
||||
sin_half_angle = np.sqrt(1.0 - qw * qw)
|
||||
|
||||
if sin_half_angle < 1e-8:
|
||||
# For very small angles, use linearization: rotvec ≈ 2 * [qx, qy, qz]
|
||||
return 2.0 * np.array([qx, qy, qz])
|
||||
|
||||
# Extract axis and scale by angle
|
||||
axis = np.array([qx, qy, qz]) / sin_half_angle
|
||||
return angle * axis
|
||||
|
||||
def as_quat(self) -> np.ndarray:
|
||||
"""
|
||||
Get quaternion representation.
|
||||
|
||||
Returns:
|
||||
Quaternion [x, y, z, w]
|
||||
"""
|
||||
return self._quat.copy()
|
||||
@@ -32,7 +32,7 @@ from lerobot.datasets.utils import load_json, write_json
|
||||
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
||||
|
||||
|
||||
@@ -75,8 +75,8 @@ def save_checkpoint(
|
||||
policy: PreTrainedPolicy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: RobotProcessor | None = None,
|
||||
postprocessor: RobotProcessor | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
|
||||
@@ -19,8 +19,6 @@ from typing import Any
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
"""Initializes the Rerun SDK for visualizing the control loop."""
|
||||
@@ -33,85 +31,67 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
|
||||
def _is_scalar(x):
|
||||
return (
|
||||
isinstance(x, numbers.Real)
|
||||
isinstance(x, float)
|
||||
or isinstance(x, numbers.Real)
|
||||
or isinstance(x, (np.integer, np.floating))
|
||||
or (isinstance(x, np.ndarray) and x.ndim == 0)
|
||||
)
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
data: list[dict[str | Any] | EnvTransition] | dict[str | Any] | EnvTransition | None = None,
|
||||
*,
|
||||
observation: dict[str, Any] | None = None,
|
||||
action: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
items = data if isinstance(data, list) else ([data] if data is not None else [])
|
||||
"""
|
||||
Logs observation and action data to Rerun for real-time visualization.
|
||||
|
||||
obs = {} if observation is None else dict(observation)
|
||||
act = {} if action is None else dict(action)
|
||||
This function iterates through the provided observation and action dictionaries and sends their contents
|
||||
to the Rerun viewer. It handles different data types appropriately:
|
||||
- Scalar values (floats, ints) are logged as `rr.Scalar`.
|
||||
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
|
||||
from CHW to HWC format and logged as `rr.Image`.
|
||||
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
|
||||
- Other multi-dimensional arrays are flattened and logged as individual scalars.
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
Keys are automatically namespaced with "observation." or "action." if not already present.
|
||||
|
||||
if any(isinstance(k, TransitionKey) for k in item.keys()):
|
||||
o = item.get(TransitionKey.OBSERVATION) or {}
|
||||
a = item.get(TransitionKey.ACTION) or {}
|
||||
if isinstance(o, dict):
|
||||
obs.update(o)
|
||||
if isinstance(a, dict):
|
||||
act.update(a)
|
||||
continue
|
||||
Args:
|
||||
observation: An optional dictionary containing observation data to log.
|
||||
action: An optional dictionary containing action data to log.
|
||||
"""
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("observation.") else f"observation.{k}"
|
||||
|
||||
keys = list(item.keys())
|
||||
has_obs = any(str(k).startswith("observation.") for k in keys)
|
||||
has_act = any(str(k).startswith("action.") for k in keys)
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalar(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
if arr.ndim == 1:
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
else:
|
||||
rr.log(key, rr.Image(arr), static=True)
|
||||
|
||||
if has_obs or has_act:
|
||||
if has_obs:
|
||||
obs.update(item)
|
||||
if has_act:
|
||||
act.update(item)
|
||||
else:
|
||||
# No prefixes: assume first is observation, second is action, others are observation
|
||||
if idx == 0:
|
||||
obs.update(item)
|
||||
elif idx == 1:
|
||||
act.update(item)
|
||||
else:
|
||||
obs.update(item)
|
||||
if action:
|
||||
for k, v in action.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("action.") else f"action.{k}"
|
||||
|
||||
for k, v in obs.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("observation.") else f"observation.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalar(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
if arr.ndim == 1:
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
else:
|
||||
rr.log(key, rr.Image(arr), static=True)
|
||||
|
||||
for k, v in act.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("action.") else f"action.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalar(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
if v.ndim == 1:
|
||||
for i, vi in enumerate(v):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
else:
|
||||
# Fall back to flattening higher-dimensional arrays
|
||||
flat = v.flatten()
|
||||
for i, vi in enumerate(flat):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalar(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
if v.ndim == 1:
|
||||
for i, vi in enumerate(v):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
else:
|
||||
# Fall back to flattening higher-dimensional arrays
|
||||
flat = v.flatten()
|
||||
for i, vi in enumerate(flat):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
|
||||
PARAMS = [
|
||||
("teleop", "left"),
|
||||
("teleop", "right"),
|
||||
("depth", "rgb"),
|
||||
# ("depth", "depth"), # Depth camera is not available yet
|
||||
]
|
||||
|
||||
|
||||
def _make_cam_manager_mock():
|
||||
c = MagicMock(name="CameraManagerMock")
|
||||
|
||||
teleop = MagicMock(name="TeleopCam")
|
||||
teleop.width = 640
|
||||
teleop.height = 480
|
||||
teleop.get_frame = MagicMock(
|
||||
side_effect=lambda *_, **__: (
|
||||
np.zeros((480, 640, 3), dtype=np.uint8),
|
||||
time.time(),
|
||||
)
|
||||
)
|
||||
|
||||
depth = MagicMock(name="DepthCam")
|
||||
depth.width = 640
|
||||
depth.height = 480
|
||||
depth.get_frame = MagicMock(
|
||||
side_effect=lambda *_, **__: (
|
||||
np.zeros((480, 640, 3), dtype=np.uint8),
|
||||
time.time(),
|
||||
)
|
||||
)
|
||||
|
||||
c.is_connected.return_value = True
|
||||
c.teleop = teleop
|
||||
c.depth = depth
|
||||
|
||||
def _connect():
|
||||
c.teleop = teleop
|
||||
c.depth = depth
|
||||
c.is_connected.return_value = True
|
||||
|
||||
def _disconnect():
|
||||
c.teleop = None
|
||||
c.depth = None
|
||||
c.is_connected.return_value = False
|
||||
|
||||
c.connect = MagicMock(side_effect=_connect)
|
||||
c.disconnect = MagicMock(side_effect=_disconnect)
|
||||
|
||||
# Mock methods
|
||||
c.initialize_cameras = MagicMock()
|
||||
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=PARAMS,
|
||||
# ids=["teleop-left", "teleop-right", "torso-rgb", "torso-depth"],
|
||||
ids=["teleop-left", "teleop-right", "torso-rgb"],
|
||||
)
|
||||
def camera(request):
|
||||
name, image_type = request.param
|
||||
with (
|
||||
patch(
|
||||
"lerobot.cameras.reachy2_camera.reachy2_camera.CameraManager",
|
||||
side_effect=lambda *a, **k: _make_cam_manager_mock(),
|
||||
),
|
||||
):
|
||||
config = Reachy2CameraConfig(name=name, image_type=image_type)
|
||||
cam = Reachy2Camera(config)
|
||||
yield cam
|
||||
if cam.is_connected:
|
||||
cam.disconnect()
|
||||
|
||||
|
||||
def test_connect(camera):
|
||||
camera.connect()
|
||||
assert camera.is_connected
|
||||
camera.cam_manager.initialize_cameras.assert_called_once()
|
||||
|
||||
|
||||
def test_read(camera):
|
||||
camera.connect()
|
||||
|
||||
img = camera.read()
|
||||
if camera.config.name == "teleop":
|
||||
camera.cam_manager.teleop.get_frame.assert_called_once()
|
||||
elif camera.config.name == "depth":
|
||||
camera.cam_manager.depth.get_frame.assert_called_once()
|
||||
assert isinstance(img, np.ndarray)
|
||||
assert img.shape == (480, 640, 3)
|
||||
|
||||
|
||||
def test_disconnect(camera):
|
||||
camera.connect()
|
||||
|
||||
camera.disconnect()
|
||||
assert not camera.is_connected
|
||||
|
||||
|
||||
def test_async_read(camera):
|
||||
camera.connect()
|
||||
try:
|
||||
img = camera.async_read()
|
||||
|
||||
assert camera.thread is not None
|
||||
assert camera.thread.is_alive()
|
||||
assert isinstance(img, np.ndarray)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_async_read_timeout(camera):
|
||||
camera.connect()
|
||||
try:
|
||||
with pytest.raises(TimeoutError):
|
||||
camera.async_read(timeout_ms=0)
|
||||
finally:
|
||||
if camera.is_connected:
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_read_before_connect(camera):
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.read()
|
||||
|
||||
|
||||
def test_disconnect_before_connect(camera):
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
def test_async_read_before_connect(camera):
|
||||
with pytest.raises(DeviceNotConnectedError):
|
||||
_ = camera.async_read()
|
||||
|
||||
|
||||
def test_wrong_camera_name():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="wrong-name", image_type="left")
|
||||
|
||||
|
||||
def test_wrong_image_type():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="teleop", image_type="rgb")
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="depth", image_type="left")
|
||||
|
||||
|
||||
def test_wrong_color_mode():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2CameraConfig(name="teleop", image_type="left", color_mode="wrong-color")
|
||||
+7
-4
@@ -19,7 +19,7 @@ import traceback
|
||||
import pytest
|
||||
from serial import SerialException
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from tests.utils import DEVICE
|
||||
|
||||
# Import fixture modules as plugins
|
||||
@@ -28,6 +28,7 @@ pytest_plugins = [
|
||||
"tests.fixtures.files",
|
||||
"tests.fixtures.hub",
|
||||
"tests.fixtures.optimizers",
|
||||
"tests.plugins.reachy2_sdk",
|
||||
]
|
||||
|
||||
|
||||
@@ -82,7 +83,9 @@ def policy_feature_factory():
|
||||
return _pf
|
||||
|
||||
|
||||
def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None:
|
||||
def assert_contract_is_typed(features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> None:
|
||||
assert isinstance(features, dict)
|
||||
assert all(isinstance(k, str) for k in features.keys())
|
||||
assert all(isinstance(v, PolicyFeature) for v in features.values())
|
||||
assert all(isinstance(k, PipelineFeatureType) for k in features.keys())
|
||||
assert all(isinstance(v, dict) for v in features.values())
|
||||
assert all(all(isinstance(nk, str) for nk in v.keys()) for v in features.values())
|
||||
assert all(all(isinstance(nv, PolicyFeature) for nv in v.values()) for v in features.values())
|
||||
|
||||
@@ -20,7 +20,7 @@ from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch, merge_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
@@ -72,7 +72,7 @@ def test_merge_simple_vectors():
|
||||
}
|
||||
}
|
||||
|
||||
out = merge_features(g1, g2)
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
|
||||
assert "action" in out
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
@@ -87,7 +87,7 @@ def test_merge_multiple_groups_order_and_dedup():
|
||||
g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
|
||||
g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
|
||||
|
||||
out = merge_features(g1, g2, g3)
|
||||
out = combine_feature_dicts(g1, g2, g3)
|
||||
|
||||
assert out["action"]["names"] == ["a", "b", "c", "d"]
|
||||
assert out["action"]["shape"] == (4,)
|
||||
@@ -110,7 +110,7 @@ def test_non_vector_last_wins_for_images():
|
||||
}
|
||||
}
|
||||
|
||||
out = merge_features(g1, g2)
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
|
||||
assert out["observation.images.front"]["dtype"] == "image"
|
||||
|
||||
@@ -120,13 +120,13 @@ def test_dtype_mismatch_raises():
|
||||
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
|
||||
|
||||
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
|
||||
_ = merge_features(g1, g2)
|
||||
_ = combine_feature_dicts(g1, g2)
|
||||
|
||||
|
||||
def test_non_dict_passthrough_last_wins():
|
||||
g1 = {"misc": 123}
|
||||
g2 = {"misc": 456}
|
||||
|
||||
out = merge_features(g1, g2)
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
# For non-dict entries the last one wins
|
||||
assert out["misc"] == 456
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def _install_reachy2_sdk_stub():
|
||||
sdk = types.ModuleType("reachy2_sdk")
|
||||
sdk.__path__ = []
|
||||
sdk.ReachySDK = MagicMock(name="ReachySDK")
|
||||
|
||||
media = types.ModuleType("reachy2_sdk.media")
|
||||
media.__path__ = []
|
||||
camera = types.ModuleType("reachy2_sdk.media.camera")
|
||||
camera.CameraView = MagicMock(name="CameraView")
|
||||
camera_manager = types.ModuleType("reachy2_sdk.media.camera_manager")
|
||||
camera_manager.CameraManager = MagicMock(name="CameraManager")
|
||||
|
||||
sdk.media = media
|
||||
media.camera = camera
|
||||
media.camera_manager = camera_manager
|
||||
|
||||
# Register in sys.modules
|
||||
sys.modules.setdefault("reachy2_sdk", sdk)
|
||||
sys.modules.setdefault("reachy2_sdk.media", media)
|
||||
sys.modules.setdefault("reachy2_sdk.media.camera", camera)
|
||||
sys.modules.setdefault("reachy2_sdk.media.camera_manager", camera_manager)
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
_install_reachy2_sdk_stub()
|
||||
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.processor_act import make_act_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
@@ -81,20 +81,20 @@ def test_make_act_processor_basic():
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
|
||||
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
|
||||
|
||||
|
||||
def test_act_processor_normalization():
|
||||
@@ -250,7 +250,7 @@ def test_act_processor_save_and_load():
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -303,11 +303,22 @@ def test_act_processor_mixed_precision():
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
# Replace DeviceProcessorStep with one that uses float16
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessor):
|
||||
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Update normalizer to use the same device as the device processor
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float16, # Match the float16 dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
@@ -353,3 +364,59 @@ def test_act_processor_batch_consistency():
|
||||
processed_batched = preprocessor(transition_batched)
|
||||
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
|
||||
assert processed_batched[TransitionKey.ACTION].shape[0] == 8
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, _ = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
# Device processor converts to bfloat16
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float32, # Deliberately configured as float32
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Verify initial normalizer configuration
|
||||
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
|
||||
assert normalizer_step.dtype == torch.float32
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
|
||||
|
||||
# Verify normalizer automatically adapted its internal state
|
||||
assert normalizer_step.dtype == torch.bfloat16
|
||||
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import (
|
||||
RobotProcessor,
|
||||
TransitionKey,
|
||||
_default_batch_to_transition,
|
||||
_default_transition_to_batch,
|
||||
)
|
||||
from lerobot.processor import DataProcessorPipeline, TransitionKey
|
||||
from lerobot.processor.converters import batch_to_transition, transition_to_batch
|
||||
|
||||
|
||||
def _dummy_batch():
|
||||
@@ -24,7 +20,7 @@ def _dummy_batch():
|
||||
|
||||
def test_observation_grouping_roundtrip():
|
||||
"""Test that observation.* keys are properly grouped and ungrouped."""
|
||||
proc = RobotProcessor([])
|
||||
proc = DataProcessorPipeline([])
|
||||
batch_in = _dummy_batch()
|
||||
batch_out = proc(batch_in)
|
||||
|
||||
@@ -48,7 +44,7 @@ def test_observation_grouping_roundtrip():
|
||||
|
||||
|
||||
def test_batch_to_transition_observation_grouping():
|
||||
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
|
||||
"""Test that batch_to_transition correctly groups observation.* keys."""
|
||||
batch = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
@@ -60,7 +56,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
"info": {"episode": 42},
|
||||
}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation is a dict with all observation.* keys
|
||||
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
|
||||
@@ -87,7 +83,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
|
||||
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
"""Test that _default_transition_to_batch correctly flattens observation dict."""
|
||||
"""Test that transition_to_batch correctly flattens observation dict."""
|
||||
observation_dict = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
@@ -104,7 +100,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
batch = _default_transition_to_batch(transition)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Check that observation.* keys are flattened back to batch
|
||||
assert "observation.image.top" in batch
|
||||
@@ -134,7 +130,7 @@ def test_no_observation_keys():
|
||||
"info": {"test": "no_obs"},
|
||||
}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Observation should be None when no observation.* keys
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
@@ -147,7 +143,7 @@ def test_no_observation_keys():
|
||||
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] == "action_data"
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
@@ -159,7 +155,7 @@ def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
@@ -173,7 +169,7 @@ def test_minimal_batch():
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert reconstructed_batch["action"] == "minimal_action"
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
@@ -186,7 +182,7 @@ def test_empty_batch():
|
||||
"""Test behavior with empty batch."""
|
||||
batch = {}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# All fields should have defaults
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
@@ -198,7 +194,7 @@ def test_empty_batch():
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] is None
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
@@ -219,8 +215,8 @@ def test_complex_nested_observation():
|
||||
"info": {"episode_length": 200, "success": True},
|
||||
}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
transition = batch_to_transition(batch)
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
|
||||
# Check that all observation keys are preserved
|
||||
original_obs_keys = {k for k in batch if k.startswith("observation.")}
|
||||
@@ -254,7 +250,7 @@ def test_custom_converter():
|
||||
|
||||
def to_tr(batch):
|
||||
# Custom converter that modifies the reward
|
||||
tr = _default_batch_to_transition(batch)
|
||||
tr = batch_to_transition(batch)
|
||||
# Double the reward
|
||||
reward = tr.get(TransitionKey.REWARD, 0.0)
|
||||
new_tr = tr.copy()
|
||||
@@ -262,10 +258,10 @@ def test_custom_converter():
|
||||
return new_tr
|
||||
|
||||
def to_batch(tr):
|
||||
batch = _default_transition_to_batch(tr)
|
||||
batch = transition_to_batch(tr)
|
||||
return batch
|
||||
|
||||
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
|
||||
processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch)
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 4),
|
||||
|
||||
@@ -22,9 +22,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor import ProcessorStepRegistry, RobotProcessor
|
||||
from lerobot.processor.batch_processor import ToBatchProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(
|
||||
@@ -44,7 +47,7 @@ def create_transition(
|
||||
|
||||
def test_state_1d_to_2d():
|
||||
"""Test that 1D state tensors get unsqueezed to 2D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test observation.state
|
||||
state_1d = torch.randn(7)
|
||||
@@ -60,7 +63,7 @@ def test_state_1d_to_2d():
|
||||
|
||||
def test_env_state_1d_to_2d():
|
||||
"""Test that 1D environment state tensors get unsqueezed to 2D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test observation.environment_state
|
||||
env_state_1d = torch.randn(10)
|
||||
@@ -76,7 +79,7 @@ def test_env_state_1d_to_2d():
|
||||
|
||||
def test_image_3d_to_4d():
|
||||
"""Test that 3D image tensors get unsqueezed to 4D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test observation.image
|
||||
image_3d = torch.randn(224, 224, 3)
|
||||
@@ -92,7 +95,7 @@ def test_image_3d_to_4d():
|
||||
|
||||
def test_multiple_images_3d_to_4d():
|
||||
"""Test that 3D image tensors in observation.images.* get unsqueezed to 4D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test observation.images.camera1 and observation.images.camera2
|
||||
image1_3d = torch.randn(64, 64, 3)
|
||||
@@ -117,7 +120,7 @@ def test_multiple_images_3d_to_4d():
|
||||
|
||||
def test_already_batched_tensors_unchanged():
|
||||
"""Test that already batched tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create already batched tensors
|
||||
state_2d = torch.randn(1, 7)
|
||||
@@ -143,7 +146,7 @@ def test_already_batched_tensors_unchanged():
|
||||
|
||||
def test_higher_dimensional_tensors_unchanged():
|
||||
"""Test that tensors with more dimensions than expected remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create tensors with more dimensions
|
||||
state_3d = torch.randn(2, 7, 5) # More than 1D
|
||||
@@ -166,7 +169,7 @@ def test_higher_dimensional_tensors_unchanged():
|
||||
|
||||
def test_non_tensor_values_unchanged():
|
||||
"""Test that non-tensor values in observations remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
observation = {
|
||||
OBS_STATE: [1, 2, 3], # List, not tensor
|
||||
@@ -189,7 +192,7 @@ def test_non_tensor_values_unchanged():
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor handles None observation gracefully."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
transition = create_transition(observation=None)
|
||||
result = processor(transition)
|
||||
@@ -199,7 +202,7 @@ def test_none_observation():
|
||||
|
||||
def test_empty_observation():
|
||||
"""Test processor handles empty observation dict."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
observation = {}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -211,7 +214,7 @@ def test_empty_observation():
|
||||
|
||||
def test_mixed_observation():
|
||||
"""Test processor with mixed observation containing various types and dimensions."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
state_1d = torch.randn(5)
|
||||
env_state_2d = torch.randn(1, 8) # Already batched
|
||||
@@ -243,9 +246,9 @@ def test_mixed_observation():
|
||||
|
||||
|
||||
def test_integration_with_robot_processor():
|
||||
"""Test ToBatchProcessor integration with RobotProcessor."""
|
||||
to_batch_processor = ToBatchProcessor()
|
||||
pipeline = RobotProcessor([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
"""Test AddBatchDimensionProcessorStep integration with RobotProcessor."""
|
||||
to_batch_processor = AddBatchDimensionProcessorStep()
|
||||
pipeline = DataProcessorPipeline([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
# Create unbatched observation
|
||||
observation = {
|
||||
@@ -263,7 +266,7 @@ def test_integration_with_robot_processor():
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test get_config, state_dict, load_state_dict, and reset methods."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
@@ -283,9 +286,9 @@ def test_serialization_methods():
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading ToBatchProcessor with RobotProcessor."""
|
||||
processor = ToBatchProcessor()
|
||||
pipeline = RobotProcessor(
|
||||
"""Test saving and loading AddBatchDimensionProcessorStep with RobotProcessor."""
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
pipeline = DataProcessorPipeline(
|
||||
[processor], name="BatchPipeline", to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -298,13 +301,13 @@ def test_save_and_load_pretrained():
|
||||
assert config_path.exists()
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
assert loaded_pipeline.name == "BatchPipeline"
|
||||
assert len(loaded_pipeline) == 1
|
||||
assert isinstance(loaded_pipeline.steps[0], ToBatchProcessor)
|
||||
assert isinstance(loaded_pipeline.steps[0], AddBatchDimensionProcessorStep)
|
||||
|
||||
# Test functionality of loaded processor
|
||||
observation = {OBS_STATE: torch.randn(5)}
|
||||
@@ -315,10 +318,10 @@ def test_save_and_load_pretrained():
|
||||
|
||||
|
||||
def test_registry_functionality():
|
||||
"""Test that ToBatchProcessor is properly registered."""
|
||||
"""Test that AddBatchDimensionProcessorStep is properly registered."""
|
||||
# Check that the processor is registered
|
||||
registered_class = ProcessorStepRegistry.get("to_batch_processor")
|
||||
assert registered_class is ToBatchProcessor
|
||||
assert registered_class is AddBatchDimensionProcessorStep
|
||||
|
||||
# Check that it's in the list of registered processors
|
||||
assert "to_batch_processor" in ProcessorStepRegistry.list()
|
||||
@@ -326,12 +329,12 @@ def test_registry_functionality():
|
||||
|
||||
def test_registry_based_save_load():
|
||||
"""Test saving and loading using registry name."""
|
||||
processor = ToBatchProcessor()
|
||||
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
pipeline = DataProcessorPipeline([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -352,7 +355,7 @@ def test_registry_based_save_load():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_device_compatibility():
|
||||
"""Test processor works with tensors on different devices."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create tensors on GPU
|
||||
state_1d = torch.randn(7, device="cuda")
|
||||
@@ -376,7 +379,7 @@ def test_device_compatibility():
|
||||
|
||||
def test_processor_preserves_other_transition_keys():
|
||||
"""Test that processor only modifies observation and preserves other transition keys."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
action = torch.randn(5)
|
||||
reward = 1.5
|
||||
@@ -413,7 +416,7 @@ def test_processor_preserves_other_transition_keys():
|
||||
|
||||
def test_edge_case_zero_dimensional_tensors():
|
||||
"""Test processor handles 0D tensors (scalars) correctly."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# 0D tensors should not be modified
|
||||
scalar_tensor = torch.tensor(42.0)
|
||||
@@ -435,7 +438,7 @@ def test_edge_case_zero_dimensional_tensors():
|
||||
# Action-specific tests
|
||||
def test_action_1d_to_2d():
|
||||
"""Test that 1D action tensors get batch dimension added."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create 1D action tensor
|
||||
action_1d = torch.randn(4)
|
||||
@@ -450,7 +453,7 @@ def test_action_1d_to_2d():
|
||||
|
||||
def test_action_already_batched():
|
||||
"""Test that already batched action tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test various batch sizes
|
||||
action_batched_1 = torch.randn(1, 4)
|
||||
@@ -469,7 +472,7 @@ def test_action_already_batched():
|
||||
|
||||
def test_action_higher_dimensional():
|
||||
"""Test that higher dimensional action tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# 3D action tensor (e.g., sequence of actions)
|
||||
action_3d = torch.randn(2, 4, 3)
|
||||
@@ -486,7 +489,7 @@ def test_action_higher_dimensional():
|
||||
|
||||
def test_action_scalar_tensor():
|
||||
"""Test that scalar (0D) action tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
action_scalar = torch.tensor(1.5)
|
||||
transition = create_transition(action=action_scalar)
|
||||
@@ -499,7 +502,7 @@ def test_action_scalar_tensor():
|
||||
|
||||
def test_action_non_tensor():
|
||||
"""Test that non-tensor actions remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# List action
|
||||
action_list = [0.1, 0.2, 0.3, 0.4]
|
||||
@@ -528,7 +531,7 @@ def test_action_non_tensor():
|
||||
|
||||
def test_action_none():
|
||||
"""Test that None action is handled correctly."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
transition = create_transition(action=None)
|
||||
result = processor(transition)
|
||||
@@ -537,7 +540,7 @@ def test_action_none():
|
||||
|
||||
def test_action_with_observation():
|
||||
"""Test action processing together with observation processing."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Both need batching
|
||||
observation = {
|
||||
@@ -557,7 +560,7 @@ def test_action_with_observation():
|
||||
|
||||
def test_action_different_sizes():
|
||||
"""Test action processing with various action dimensions."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Different action sizes (robot with different DOF)
|
||||
action_sizes = [1, 2, 4, 7, 10, 20]
|
||||
@@ -574,7 +577,7 @@ def test_action_different_sizes():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_action_device_compatibility():
|
||||
"""Test action processing on different devices."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# CUDA action
|
||||
action_cuda = torch.randn(4, device="cuda")
|
||||
@@ -595,7 +598,7 @@ def test_action_device_compatibility():
|
||||
|
||||
def test_action_dtype_preservation():
|
||||
"""Test that action dtype is preserved during processing."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Different dtypes
|
||||
dtypes = [torch.float32, torch.float64, torch.int32, torch.int64]
|
||||
@@ -611,7 +614,7 @@ def test_action_dtype_preservation():
|
||||
|
||||
def test_empty_action_tensor():
|
||||
"""Test handling of empty action tensors."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Empty 1D tensor
|
||||
action_empty = torch.tensor([])
|
||||
@@ -633,7 +636,7 @@ def test_empty_action_tensor():
|
||||
# Task-specific tests
|
||||
def test_task_string_to_list():
|
||||
"""Test that string tasks get wrapped in lists to add batch dimension."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create complementary data with string task
|
||||
complementary_data = {"task": "pick_cube"}
|
||||
@@ -650,7 +653,7 @@ def test_task_string_to_list():
|
||||
|
||||
def test_task_string_validation():
|
||||
"""Test that only string and list of strings are valid task values."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Valid string task - should be converted to list
|
||||
complementary_data = {"task": "valid_task"}
|
||||
@@ -669,7 +672,7 @@ def test_task_string_validation():
|
||||
|
||||
def test_task_list_of_strings():
|
||||
"""Test that lists of strings remain unchanged (already batched)."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test various list of strings
|
||||
test_lists = [
|
||||
@@ -695,7 +698,7 @@ def test_task_list_of_strings():
|
||||
|
||||
def test_complementary_data_none():
|
||||
"""Test processor handles None complementary_data gracefully."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
transition = create_transition(complementary_data=None)
|
||||
result = processor(transition)
|
||||
@@ -705,7 +708,7 @@ def test_complementary_data_none():
|
||||
|
||||
def test_complementary_data_empty():
|
||||
"""Test processor handles empty complementary_data dict."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
@@ -717,7 +720,7 @@ def test_complementary_data_empty():
|
||||
|
||||
def test_complementary_data_no_task():
|
||||
"""Test processor handles complementary_data without task field."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {
|
||||
"episode_id": 123,
|
||||
@@ -735,7 +738,7 @@ def test_complementary_data_no_task():
|
||||
|
||||
def test_complementary_data_mixed():
|
||||
"""Test processor with mixed complementary_data containing task and other fields."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {
|
||||
"task": "stack_blocks",
|
||||
@@ -760,7 +763,7 @@ def test_complementary_data_mixed():
|
||||
|
||||
def test_task_with_observation_and_action():
|
||||
"""Test task processing together with observation and action processing."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# All components need batching
|
||||
observation = {
|
||||
@@ -785,7 +788,7 @@ def test_task_with_observation_and_action():
|
||||
|
||||
def test_task_comprehensive_string_cases():
|
||||
"""Test task processing with comprehensive string cases and edge cases."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test various string formats
|
||||
string_tasks = [
|
||||
@@ -843,7 +846,7 @@ def test_task_comprehensive_string_cases():
|
||||
|
||||
def test_task_preserves_other_keys():
|
||||
"""Test that task processing preserves other keys in complementary_data."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {
|
||||
"task": "clean_table",
|
||||
@@ -871,7 +874,7 @@ def test_task_preserves_other_keys():
|
||||
# Index and task_index specific tests
|
||||
def test_index_scalar_to_1d():
|
||||
"""Test that 0D index tensor gets unsqueezed to 1D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create 0D index tensor (scalar)
|
||||
index_0d = torch.tensor(42, dtype=torch.int64)
|
||||
@@ -888,7 +891,7 @@ def test_index_scalar_to_1d():
|
||||
|
||||
def test_task_index_scalar_to_1d():
|
||||
"""Test that 0D task_index tensor gets unsqueezed to 1D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create 0D task_index tensor (scalar)
|
||||
task_index_0d = torch.tensor(7, dtype=torch.int64)
|
||||
@@ -905,7 +908,7 @@ def test_task_index_scalar_to_1d():
|
||||
|
||||
def test_index_and_task_index_together():
|
||||
"""Test processing both index and task_index together."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create 0D tensors for both
|
||||
index_0d = torch.tensor(100, dtype=torch.int64)
|
||||
@@ -935,7 +938,7 @@ def test_index_and_task_index_together():
|
||||
|
||||
def test_index_already_batched():
|
||||
"""Test that already batched index tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create already batched tensors
|
||||
index_1d = torch.tensor([42], dtype=torch.int64)
|
||||
@@ -956,7 +959,7 @@ def test_index_already_batched():
|
||||
|
||||
def test_task_index_already_batched():
|
||||
"""Test that already batched task_index tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create already batched tensors
|
||||
task_index_1d = torch.tensor([7], dtype=torch.int64)
|
||||
@@ -977,7 +980,7 @@ def test_task_index_already_batched():
|
||||
|
||||
def test_index_non_tensor_unchanged():
|
||||
"""Test that non-tensor index values remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {
|
||||
"index": 42, # Plain int, not tensor
|
||||
@@ -994,7 +997,7 @@ def test_index_non_tensor_unchanged():
|
||||
|
||||
def test_index_dtype_preservation():
|
||||
"""Test that index and task_index dtype is preserved during processing."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test different dtypes
|
||||
dtypes = [torch.int32, torch.int64, torch.long]
|
||||
@@ -1017,7 +1020,7 @@ def test_index_dtype_preservation():
|
||||
|
||||
def test_index_with_full_transition():
|
||||
"""Test index/task_index processing with full transition data."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create full transition with all components
|
||||
observation = {
|
||||
@@ -1059,7 +1062,7 @@ def test_index_with_full_transition():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_index_device_compatibility():
|
||||
"""Test processor works with index/task_index tensors on different devices."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create tensors on GPU
|
||||
index_0d = torch.tensor(42, dtype=torch.int64, device="cuda")
|
||||
@@ -1083,7 +1086,7 @@ def test_index_device_compatibility():
|
||||
|
||||
def test_empty_index_tensor():
|
||||
"""Test handling of empty index tensors."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Empty 0D tensor doesn't make sense, but test empty 1D
|
||||
index_empty = torch.tensor([], dtype=torch.int64)
|
||||
@@ -1098,7 +1101,7 @@ def test_empty_index_tensor():
|
||||
|
||||
def test_action_processing_creates_new_transition():
|
||||
"""Test that the processor creates a new transition object with correctly processed action."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(action=action)
|
||||
@@ -1120,7 +1123,7 @@ def test_action_processing_creates_new_transition():
|
||||
|
||||
def test_task_processing_creates_new_transition():
|
||||
"""Test that the processor creates a new transition object with correctly processed task."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {"task": "sort_objects"}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
@@ -24,8 +24,13 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
||||
from lerobot.processor import DeviceProcessor, IdentityProcessor, NormalizerProcessor, RobotProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
IdentityProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
@@ -82,14 +87,14 @@ def test_make_classifier_processor_basic():
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 3
|
||||
assert isinstance(preprocessor.steps[0], NormalizerProcessor) # For input features
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor) # For output features
|
||||
assert isinstance(preprocessor.steps[2], DeviceProcessor)
|
||||
assert isinstance(preprocessor.steps[0], NormalizerProcessorStep) # For input features
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) # For output features
|
||||
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], IdentityProcessor)
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], IdentityProcessorStep)
|
||||
|
||||
|
||||
def test_classifier_processor_normalization():
|
||||
@@ -249,7 +254,7 @@ def test_classifier_processor_save_and_load():
|
||||
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
preprocessor = DataProcessorPipeline(
|
||||
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -258,7 +263,7 @@ def test_classifier_processor_save_and_load():
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -286,16 +291,16 @@ def test_classifier_processor_mixed_precision():
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
# Replace DeviceProcessorStep with one that uses float16
|
||||
modified_steps = []
|
||||
for step in factory_preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessor):
|
||||
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
|
||||
preprocessor = DataProcessorPipeline(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
|
||||
+120
-102
@@ -2,112 +2,16 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.processor.converters import (
|
||||
to_dataset_frame,
|
||||
to_output_robot_action,
|
||||
batch_to_transition,
|
||||
to_tensor,
|
||||
to_transition_robot_observation,
|
||||
to_transition_teleop_action,
|
||||
transition_to_batch,
|
||||
transition_to_dataset_frame,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def test_to_transition_teleop_action_prefix_and_tensor_conversion():
|
||||
# Scalars, arrays, and uint8 arrays are all converted to tensors
|
||||
img = np.zeros((8, 12, 3), dtype=np.uint8)
|
||||
act = {
|
||||
"ee.x": 0.5, # scalar to torch tensor
|
||||
"delta": np.array([1.0, 2.0]), # ndarray to torch tensor
|
||||
"raw_img": img, # uint8 HWC to torch tensor
|
||||
}
|
||||
|
||||
tr = to_transition_teleop_action(act)
|
||||
|
||||
# Should be an EnvTransition-like dict with ACTION populated
|
||||
assert isinstance(tr, dict)
|
||||
assert TransitionKey.ACTION in tr
|
||||
assert "action.ee.x" in tr[TransitionKey.ACTION]
|
||||
assert "action.delta" in tr[TransitionKey.ACTION]
|
||||
assert "action.raw_img" in tr[TransitionKey.ACTION]
|
||||
|
||||
# Types: all values -> torch tensor
|
||||
assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor)
|
||||
assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5)
|
||||
|
||||
assert isinstance(tr[TransitionKey.ACTION]["action.delta"], torch.Tensor)
|
||||
assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,)
|
||||
assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0]))
|
||||
|
||||
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], torch.Tensor)
|
||||
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == torch.float32 # converted from uint8
|
||||
assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3)
|
||||
|
||||
# Observation is created as empty dict by make_transition
|
||||
assert TransitionKey.OBSERVATION in tr
|
||||
assert isinstance(tr[TransitionKey.OBSERVATION], dict)
|
||||
assert tr[TransitionKey.OBSERVATION] == {}
|
||||
|
||||
|
||||
def test_to_transition_robot_observation_state_vs_images_split():
|
||||
# Create an observation with mixed content
|
||||
img = np.full((10, 20, 3), 255, dtype=np.uint8) # image (uint8 HWC)
|
||||
obs = {
|
||||
"j1.pos": 10.0, # scalar to state to torch tensor
|
||||
"j2.pos": np.float32(20.0), # scalar np to state to torch tensor
|
||||
"image_front": img, # to images passthrough
|
||||
"flag": np.int32(7), # scalar to state to torch tensor
|
||||
"arr": np.array([1.5, 2.5]), # vector to state to torch tensor
|
||||
}
|
||||
|
||||
tr = to_transition_robot_observation(obs)
|
||||
assert isinstance(tr, dict)
|
||||
assert TransitionKey.OBSERVATION in tr
|
||||
|
||||
out = tr[TransitionKey.OBSERVATION]
|
||||
# Check state keys are present and converted to tensors
|
||||
for k in ("j1.pos", "j2.pos", "flag", "arr"):
|
||||
key = f"observation.state.{k}"
|
||||
assert key in out
|
||||
v = out[key]
|
||||
if k != "arr":
|
||||
assert isinstance(v, torch.Tensor) and v.ndim == 0
|
||||
else:
|
||||
assert isinstance(v, torch.Tensor) and v.ndim == 1 and v.shape == (2,)
|
||||
|
||||
# Check image present as is
|
||||
assert "observation.images.image_front" in out
|
||||
assert isinstance(out["observation.images.image_front"], np.ndarray)
|
||||
assert out["observation.images.image_front"].dtype == np.uint8
|
||||
assert out["observation.images.image_front"].shape == (10, 20, 3)
|
||||
|
||||
# ACTION should be empty dict by make_transition
|
||||
assert TransitionKey.ACTION in tr
|
||||
assert isinstance(tr[TransitionKey.ACTION], dict)
|
||||
assert tr[TransitionKey.ACTION] == {}
|
||||
|
||||
|
||||
def test_to_output_robot_action_strips_prefix_and_filters_pos_keys_only():
|
||||
# Build a transition with mixed action keys
|
||||
tr = {
|
||||
TransitionKey.ACTION: {
|
||||
"action.j1.pos": 11.0, # keep "j1.pos"
|
||||
"action.gripper.pos": torch.tensor(33.0), # keep: tensor accepted
|
||||
"action.ee.x": 0.5, # ignore (doesn't end with .pos)
|
||||
"misc": "ignore_me", # ignore (no 'action.' prefix)
|
||||
}
|
||||
}
|
||||
|
||||
out = to_output_robot_action(tr)
|
||||
# Only ".pos" keys with "action." prefix are retained and stripped to base names
|
||||
assert set(out.keys()) == {"j1.pos", "gripper.pos"}
|
||||
# Values converted to float
|
||||
assert isinstance(out["j1.pos"], float)
|
||||
assert isinstance(out["gripper.pos"], float)
|
||||
assert out["j1.pos"] == pytest.approx(11.0)
|
||||
assert out["gripper.pos"] == pytest.approx(33.0)
|
||||
|
||||
|
||||
def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
|
||||
def test_transition_to_dataset_frame_merge_and_pack_vectors_and_metadata():
|
||||
# Fabricate dataset features (as stored in dataset.meta["features"])
|
||||
features = {
|
||||
# Action vector: 3 elements in specific order
|
||||
@@ -160,7 +64,7 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
|
||||
}
|
||||
|
||||
# Directly call the refactored function
|
||||
batch = to_dataset_frame([teleop_transition, robot_transition], features)
|
||||
batch = transition_to_dataset_frame([teleop_transition, robot_transition], features)
|
||||
|
||||
# Images passthrough
|
||||
assert "observation.images.front" in batch
|
||||
@@ -377,3 +281,117 @@ def test_to_tensor_unsupported_type():
|
||||
|
||||
with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
|
||||
to_tensor(object())
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info if info is not None else {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
|
||||
}
|
||||
|
||||
|
||||
def test_batch_to_transition_with_index_fields():
|
||||
"""Test that batch_to_transition handles index and task_index fields correctly."""
|
||||
|
||||
# Create batch with index and task_index fields
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
"next.reward": 1.5,
|
||||
"next.done": False,
|
||||
"task": ["pick_cube"],
|
||||
"index": torch.tensor([42], dtype=torch.int64),
|
||||
"task_index": torch.tensor([3], dtype=torch.int64),
|
||||
}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check basic transition structure
|
||||
assert TransitionKey.OBSERVATION in transition
|
||||
assert TransitionKey.ACTION in transition
|
||||
assert TransitionKey.COMPLEMENTARY_DATA in transition
|
||||
|
||||
# Check that index and task_index are in complementary_data
|
||||
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert "index" in comp_data
|
||||
assert "task_index" in comp_data
|
||||
assert "task" in comp_data
|
||||
|
||||
# Verify values
|
||||
assert torch.equal(comp_data["index"], batch["index"])
|
||||
assert torch.equal(comp_data["task_index"], batch["task_index"])
|
||||
assert comp_data["task"] == batch["task"]
|
||||
|
||||
|
||||
def testtransition_to_batch_with_index_fields():
|
||||
"""Test that transition_to_batch handles index and task_index fields correctly."""
|
||||
|
||||
# Create transition with index and task_index in complementary_data
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
action=torch.randn(1, 4),
|
||||
reward=1.5,
|
||||
done=False,
|
||||
complementary_data={
|
||||
"task": ["navigate"],
|
||||
"index": torch.tensor([100], dtype=torch.int64),
|
||||
"task_index": torch.tensor([5], dtype=torch.int64),
|
||||
},
|
||||
)
|
||||
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Check that index and task_index are in the batch
|
||||
assert "index" in batch
|
||||
assert "task_index" in batch
|
||||
assert "task" in batch
|
||||
|
||||
# Verify values
|
||||
assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"])
|
||||
assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"])
|
||||
assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"]
|
||||
|
||||
|
||||
def test_batch_to_transition_without_index_fields():
|
||||
"""Test that conversion works without index and task_index fields."""
|
||||
|
||||
# Batch without index/task_index
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
"task": ["pick_cube"],
|
||||
}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
# Should have task but not index/task_index
|
||||
assert "task" in comp_data
|
||||
assert "index" not in comp_data
|
||||
assert "task_index" not in comp_data
|
||||
|
||||
|
||||
def test_transition_to_batch_without_index_fields():
|
||||
"""Test that conversion works without index and task_index fields."""
|
||||
|
||||
# Transition without index/task_index
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
action=torch.randn(1, 4),
|
||||
complementary_data={"task": ["navigate"]},
|
||||
)
|
||||
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Should have task but not index/task_index
|
||||
assert "task" in batch
|
||||
assert "index" not in batch
|
||||
assert "task_index" not in batch
|
||||
|
||||
@@ -18,9 +18,8 @@ import tempfile
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.processor import DeviceProcessor, RobotProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
@@ -47,7 +46,7 @@ def create_transition(
|
||||
|
||||
def test_basic_functionality():
|
||||
"""Test basic device processor functionality on CPU."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
# Create a transition with CPU tensors
|
||||
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
|
||||
@@ -74,7 +73,7 @@ def test_basic_functionality():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_cuda_functionality():
|
||||
"""Test device processor functionality on CUDA."""
|
||||
processor = DeviceProcessor(device="cuda")
|
||||
processor = DeviceProcessorStep(device="cuda")
|
||||
|
||||
# Create a transition with CPU tensors
|
||||
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
|
||||
@@ -101,7 +100,7 @@ def test_cuda_functionality():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_specific_cuda_device():
|
||||
"""Test device processor with specific CUDA device."""
|
||||
processor = DeviceProcessor(device="cuda:0")
|
||||
processor = DeviceProcessorStep(device="cuda:0")
|
||||
|
||||
observation = {"observation.state": torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
@@ -117,7 +116,7 @@ def test_specific_cuda_device():
|
||||
|
||||
def test_non_tensor_values():
|
||||
"""Test that non-tensor values are preserved."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.randn(10),
|
||||
@@ -143,7 +142,7 @@ def test_non_tensor_values():
|
||||
|
||||
def test_none_values():
|
||||
"""Test handling of None values."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
# Test with None observation
|
||||
transition = create_transition(observation=None, action=torch.randn(5))
|
||||
@@ -160,7 +159,7 @@ def test_none_values():
|
||||
|
||||
def test_empty_observation():
|
||||
"""Test handling of empty observation dictionary."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
transition = create_transition(observation={}, action=torch.randn(5))
|
||||
result = processor(transition)
|
||||
@@ -171,7 +170,7 @@ def test_empty_observation():
|
||||
|
||||
def test_scalar_tensors():
|
||||
"""Test handling of scalar tensors."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
observation = {"observation.scalar": torch.tensor(1.5)}
|
||||
action = torch.tensor(2.0)
|
||||
@@ -188,7 +187,7 @@ def test_scalar_tensors():
|
||||
|
||||
def test_dtype_preservation():
|
||||
"""Test that tensor dtypes are preserved."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
observation = {
|
||||
"observation.float32": torch.randn(5, dtype=torch.float32),
|
||||
@@ -210,7 +209,7 @@ def test_dtype_preservation():
|
||||
|
||||
def test_shape_preservation():
|
||||
"""Test that tensor shapes are preserved."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
observation = {
|
||||
"observation.1d": torch.randn(10),
|
||||
@@ -233,7 +232,7 @@ def test_shape_preservation():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_mixed_devices():
|
||||
"""Test handling of tensors already on different devices."""
|
||||
processor = DeviceProcessor(device="cuda")
|
||||
processor = DeviceProcessorStep(device="cuda")
|
||||
|
||||
# Create tensors on different devices
|
||||
observation = {
|
||||
@@ -254,22 +253,22 @@ def test_mixed_devices():
|
||||
def test_non_blocking_flag():
|
||||
"""Test that non_blocking flag is set correctly."""
|
||||
# CPU processor should have non_blocking=False
|
||||
cpu_processor = DeviceProcessor(device="cpu")
|
||||
cpu_processor = DeviceProcessorStep(device="cpu")
|
||||
assert cpu_processor.non_blocking is False
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# CUDA processor should have non_blocking=True
|
||||
cuda_processor = DeviceProcessor(device="cuda")
|
||||
cuda_processor = DeviceProcessorStep(device="cuda")
|
||||
assert cuda_processor.non_blocking is True
|
||||
|
||||
cuda_0_processor = DeviceProcessor(device="cuda:0")
|
||||
cuda_0_processor = DeviceProcessorStep(device="cuda:0")
|
||||
assert cuda_0_processor.non_blocking is True
|
||||
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test get_config, state_dict, and load_state_dict methods."""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
processor = DeviceProcessor(device=device)
|
||||
processor = DeviceProcessorStep(device=device)
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
@@ -290,11 +289,13 @@ def test_serialization_methods():
|
||||
|
||||
def test_features():
|
||||
"""Test that features returns features unchanged."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
||||
},
|
||||
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
}
|
||||
|
||||
result = processor.transform_features(features)
|
||||
@@ -305,13 +306,13 @@ def test_features():
|
||||
def test_integration_with_robot_processor():
|
||||
"""Test integration with RobotProcessor."""
|
||||
from lerobot.constants import OBS_STATE
|
||||
from lerobot.processor import ToBatchProcessor
|
||||
from lerobot.processor import AddBatchDimensionProcessorStep
|
||||
|
||||
# Create a pipeline with DeviceProcessor
|
||||
device_processor = DeviceProcessor(device="cpu")
|
||||
batch_processor = ToBatchProcessor()
|
||||
# Create a pipeline with DeviceProcessorStep
|
||||
device_processor = DeviceProcessorStep(device="cpu")
|
||||
batch_processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
processor = RobotProcessor(
|
||||
processor = DataProcessorPipeline(
|
||||
steps=[batch_processor, device_processor],
|
||||
name="test_pipeline",
|
||||
to_transition=lambda x: x,
|
||||
@@ -334,21 +335,21 @@ def test_integration_with_robot_processor():
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading processor with DeviceProcessor."""
|
||||
"""Test saving and loading processor with DeviceProcessorStep."""
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
processor = DeviceProcessor(device=device, float_dtype="float16")
|
||||
robot_processor = RobotProcessor(steps=[processor], name="device_test_processor")
|
||||
processor = DeviceProcessorStep(device=device, float_dtype="float16")
|
||||
robot_processor = DataProcessorPipeline(steps=[processor], name="device_test_processor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save
|
||||
robot_processor.save_pretrained(tmpdir)
|
||||
|
||||
# Load
|
||||
loaded_processor = RobotProcessor.from_pretrained(tmpdir)
|
||||
loaded_processor = DataProcessorPipeline.from_pretrained(tmpdir)
|
||||
|
||||
assert len(loaded_processor.steps) == 1
|
||||
loaded_device_processor = loaded_processor.steps[0]
|
||||
assert isinstance(loaded_device_processor, DeviceProcessor)
|
||||
assert isinstance(loaded_device_processor, DeviceProcessorStep)
|
||||
# Use getattr to access attributes safely
|
||||
assert (
|
||||
getattr(loaded_device_processor, "device", None) == device.split(":")[0]
|
||||
@@ -357,18 +358,18 @@ def test_save_and_load_pretrained():
|
||||
|
||||
|
||||
def test_registry_functionality():
|
||||
"""Test that DeviceProcessor is properly registered."""
|
||||
from lerobot.processor.pipeline import ProcessorStepRegistry
|
||||
"""Test that DeviceProcessorStep is properly registered."""
|
||||
from lerobot.processor import ProcessorStepRegistry
|
||||
|
||||
# Check that DeviceProcessor is registered
|
||||
# Check that DeviceProcessorStep is registered
|
||||
registered_class = ProcessorStepRegistry.get("device_processor")
|
||||
assert registered_class is DeviceProcessor
|
||||
assert registered_class is DeviceProcessorStep
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_performance_with_large_tensors():
|
||||
"""Test performance with large tensors and non_blocking flag."""
|
||||
processor = DeviceProcessor(device="cuda")
|
||||
processor = DeviceProcessorStep(device="cuda")
|
||||
|
||||
# Create large tensors
|
||||
observation = {
|
||||
@@ -390,7 +391,7 @@ def test_performance_with_large_tensors():
|
||||
|
||||
def test_reward_done_truncated_types():
|
||||
"""Test handling of different types for reward, done, and truncated."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
# Test with scalar values (not tensors)
|
||||
transition = create_transition(
|
||||
@@ -430,7 +431,7 @@ def test_reward_done_truncated_types():
|
||||
|
||||
def test_complementary_data_preserved():
|
||||
"""Test that complementary_data is preserved unchanged."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
complementary_data = {
|
||||
"task": "pick_object",
|
||||
@@ -450,13 +451,13 @@ def test_complementary_data_preserved():
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick_object"
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["episode_id"] == 42
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["metadata"] == {"sensor": "camera_1"}
|
||||
# Note: Currently DeviceProcessor doesn't process tensors in complementary_data
|
||||
# Note: Currently DeviceProcessorStep doesn't process tensors in complementary_data
|
||||
# This is intentional as complementary_data is typically metadata
|
||||
|
||||
|
||||
def test_float_dtype_conversion():
|
||||
"""Test float dtype conversion functionality."""
|
||||
processor = DeviceProcessor(device="cpu", float_dtype="float16")
|
||||
processor = DeviceProcessorStep(device="cpu", float_dtype="float16")
|
||||
|
||||
# Create tensors of different types
|
||||
observation = {
|
||||
@@ -486,7 +487,7 @@ def test_float_dtype_conversion():
|
||||
|
||||
def test_float_dtype_none():
|
||||
"""Test that when float_dtype is None, no dtype conversion occurs."""
|
||||
processor = DeviceProcessor(device="cpu", float_dtype=None)
|
||||
processor = DeviceProcessorStep(device="cpu", float_dtype=None)
|
||||
|
||||
observation = {
|
||||
"observation.float32": torch.randn(5, dtype=torch.float32),
|
||||
@@ -507,7 +508,7 @@ def test_float_dtype_none():
|
||||
|
||||
def test_float_dtype_bfloat16():
|
||||
"""Test conversion to bfloat16."""
|
||||
processor = DeviceProcessor(device="cpu", float_dtype="bfloat16")
|
||||
processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16")
|
||||
|
||||
observation = {"observation.state": torch.randn(5, dtype=torch.float32)}
|
||||
action = torch.randn(3, dtype=torch.float64)
|
||||
@@ -521,7 +522,7 @@ def test_float_dtype_bfloat16():
|
||||
|
||||
def test_float_dtype_float64():
|
||||
"""Test conversion to float64."""
|
||||
processor = DeviceProcessor(device="cpu", float_dtype="float64")
|
||||
processor = DeviceProcessorStep(device="cpu", float_dtype="float64")
|
||||
|
||||
observation = {"observation.state": torch.randn(5, dtype=torch.float16)}
|
||||
action = torch.randn(3, dtype=torch.float32)
|
||||
@@ -536,27 +537,27 @@ def test_float_dtype_float64():
|
||||
def test_float_dtype_invalid():
|
||||
"""Test that invalid float_dtype raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid float_dtype 'invalid_dtype'"):
|
||||
DeviceProcessor(device="cpu", float_dtype="invalid_dtype")
|
||||
DeviceProcessorStep(device="cpu", float_dtype="invalid_dtype")
|
||||
|
||||
|
||||
def test_float_dtype_aliases():
|
||||
"""Test that dtype aliases work correctly."""
|
||||
# Test 'half' alias for float16
|
||||
processor_half = DeviceProcessor(device="cpu", float_dtype="half")
|
||||
processor_half = DeviceProcessorStep(device="cpu", float_dtype="half")
|
||||
assert processor_half._target_float_dtype == torch.float16
|
||||
|
||||
# Test 'float' alias for float32
|
||||
processor_float = DeviceProcessor(device="cpu", float_dtype="float")
|
||||
processor_float = DeviceProcessorStep(device="cpu", float_dtype="float")
|
||||
assert processor_float._target_float_dtype == torch.float32
|
||||
|
||||
# Test 'double' alias for float64
|
||||
processor_double = DeviceProcessor(device="cpu", float_dtype="double")
|
||||
processor_double = DeviceProcessorStep(device="cpu", float_dtype="double")
|
||||
assert processor_double._target_float_dtype == torch.float64
|
||||
|
||||
|
||||
def test_float_dtype_with_mixed_tensors():
|
||||
"""Test float dtype conversion with mixed tensor types."""
|
||||
processor = DeviceProcessor(device="cpu", float_dtype="float32")
|
||||
processor = DeviceProcessorStep(device="cpu", float_dtype="float32")
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert
|
||||
@@ -580,13 +581,13 @@ def test_float_dtype_with_mixed_tensors():
|
||||
def test_float_dtype_serialization():
|
||||
"""Test that float_dtype is properly serialized in get_config."""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
processor = DeviceProcessor(device=device, float_dtype="float16")
|
||||
processor = DeviceProcessorStep(device=device, float_dtype="float16")
|
||||
config = processor.get_config()
|
||||
|
||||
assert config == {"device": device, "float_dtype": "float16"}
|
||||
|
||||
# Test with None float_dtype
|
||||
processor_none = DeviceProcessor(device="cpu", float_dtype=None)
|
||||
processor_none = DeviceProcessorStep(device="cpu", float_dtype=None)
|
||||
config_none = processor_none.get_config()
|
||||
|
||||
assert config_none == {"device": "cpu", "float_dtype": None}
|
||||
@@ -595,7 +596,7 @@ def test_float_dtype_serialization():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_float_dtype_with_cuda():
|
||||
"""Test float dtype conversion combined with CUDA device."""
|
||||
processor = DeviceProcessor(device="cuda", float_dtype="float16")
|
||||
processor = DeviceProcessorStep(device="cuda", float_dtype="float16")
|
||||
|
||||
# Create tensors on CPU with different dtypes
|
||||
observation = {
|
||||
@@ -620,7 +621,7 @@ def test_float_dtype_with_cuda():
|
||||
|
||||
def test_complementary_data_index_fields():
|
||||
"""Test processing of index and task_index fields in complementary_data."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
# Create transition with index and task_index in complementary_data
|
||||
complementary_data = {
|
||||
@@ -658,7 +659,7 @@ def test_complementary_data_index_fields():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_complementary_data_index_fields_cuda():
|
||||
"""Test moving index and task_index fields to CUDA."""
|
||||
processor = DeviceProcessor(device="cuda:0")
|
||||
processor = DeviceProcessorStep(device="cuda:0")
|
||||
|
||||
# Create CPU tensors
|
||||
complementary_data = {
|
||||
@@ -680,7 +681,7 @@ def test_complementary_data_index_fields_cuda():
|
||||
|
||||
def test_complementary_data_without_index_fields():
|
||||
"""Test that complementary_data without index/task_index fields works correctly."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
complementary_data = {
|
||||
"task": ["navigate"],
|
||||
@@ -698,7 +699,7 @@ def test_complementary_data_without_index_fields():
|
||||
|
||||
def test_complementary_data_mixed_tensors():
|
||||
"""Test complementary_data with mix of tensors and non-tensors."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
complementary_data = {
|
||||
"task": ["pick_and_place"],
|
||||
@@ -727,7 +728,7 @@ def test_complementary_data_mixed_tensors():
|
||||
|
||||
def test_complementary_data_float_dtype_conversion():
|
||||
"""Test that float dtype conversion doesn't affect int tensors in complementary_data."""
|
||||
processor = DeviceProcessor(device="cpu", float_dtype="float16")
|
||||
processor = DeviceProcessorStep(device="cpu", float_dtype="float16")
|
||||
|
||||
complementary_data = {
|
||||
"index": torch.tensor([42], dtype=torch.int64),
|
||||
@@ -751,7 +752,7 @@ def test_complementary_data_float_dtype_conversion():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_complementary_data_full_pipeline_cuda():
|
||||
"""Test full transition with complementary_data on CUDA."""
|
||||
processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
|
||||
processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16")
|
||||
|
||||
# Create full transition with mixed CPU tensors
|
||||
observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)}
|
||||
@@ -797,7 +798,7 @@ def test_complementary_data_full_pipeline_cuda():
|
||||
|
||||
def test_complementary_data_empty():
|
||||
"""Test empty complementary_data handling."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
@@ -812,7 +813,7 @@ def test_complementary_data_empty():
|
||||
|
||||
def test_complementary_data_none():
|
||||
"""Test None complementary_data handling."""
|
||||
processor = DeviceProcessor(device="cpu")
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
@@ -827,8 +828,8 @@ def test_complementary_data_none():
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_preserves_gpu_placement():
|
||||
"""Test that DeviceProcessor preserves GPU placement when tensor is already on GPU."""
|
||||
processor = DeviceProcessor(device="cuda:0")
|
||||
"""Test that DeviceProcessorStep preserves GPU placement when tensor is already on GPU."""
|
||||
processor = DeviceProcessorStep(device="cuda:0")
|
||||
|
||||
# Create tensors already on GPU
|
||||
observation = {
|
||||
@@ -853,9 +854,9 @@ def test_preserves_gpu_placement():
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_multi_gpu_preservation():
|
||||
"""Test that DeviceProcessor preserves placement on different GPUs in multi-GPU setup."""
|
||||
"""Test that DeviceProcessorStep preserves placement on different GPUs in multi-GPU setup."""
|
||||
# Test 1: GPU-to-GPU preservation (cuda:0 config, cuda:1 input)
|
||||
processor_gpu = DeviceProcessor(device="cuda:0")
|
||||
processor_gpu = DeviceProcessorStep(device="cuda:0")
|
||||
|
||||
# Create tensors on cuda:1 (simulating Accelerate placement)
|
||||
cuda1_device = torch.device("cuda:1")
|
||||
@@ -874,7 +875,7 @@ def test_multi_gpu_preservation():
|
||||
assert result[TransitionKey.ACTION].device == cuda1_device
|
||||
|
||||
# Test 2: GPU-to-CPU should move to CPU (not preserve GPU)
|
||||
processor_cpu = DeviceProcessor(device="cpu")
|
||||
processor_cpu = DeviceProcessorStep(device="cpu")
|
||||
|
||||
transition_gpu = create_transition(
|
||||
observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda()
|
||||
@@ -890,7 +891,7 @@ def test_multi_gpu_preservation():
|
||||
def test_multi_gpu_with_cpu_tensors():
|
||||
"""Test that CPU tensors are moved to configured device even in multi-GPU context."""
|
||||
# Processor configured for cuda:1
|
||||
processor = DeviceProcessor(device="cuda:1")
|
||||
processor = DeviceProcessorStep(device="cuda:1")
|
||||
|
||||
# Mix of CPU and GPU tensors
|
||||
observation = {
|
||||
@@ -917,7 +918,7 @@ def test_multi_gpu_with_cpu_tensors():
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_multi_gpu_with_float_dtype():
|
||||
"""Test float dtype conversion works correctly with multi-GPU preservation."""
|
||||
processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
|
||||
processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16")
|
||||
|
||||
# Create float tensors on different GPUs
|
||||
observation = {
|
||||
@@ -947,7 +948,7 @@ def test_simulated_accelerate_scenario():
|
||||
for gpu_id in range(min(torch.cuda.device_count(), 2)):
|
||||
# Each "process" has a processor configured for cuda:0
|
||||
# but data comes in already placed on the process's GPU
|
||||
processor = DeviceProcessor(device="cuda:0")
|
||||
processor = DeviceProcessorStep(device="cuda:0")
|
||||
|
||||
# Simulate data already placed by Accelerate
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
@@ -967,7 +968,11 @@ def test_policy_processor_integration():
|
||||
"""Test integration with policy processors - input on GPU, output on CPU."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.processor import NormalizerProcessor, ToBatchProcessor, UnnormalizerProcessor
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
# Create features and stats
|
||||
features = {
|
||||
@@ -983,11 +988,11 @@ def test_policy_processor_integration():
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
|
||||
# Create input processor (preprocessor) that moves to GPU
|
||||
input_processor = RobotProcessor(
|
||||
input_processor = DataProcessorPipeline(
|
||||
steps=[
|
||||
NormalizerProcessor(features=features, norm_map=norm_map, stats=stats),
|
||||
ToBatchProcessor(),
|
||||
DeviceProcessor(device="cuda"),
|
||||
NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device="cuda"),
|
||||
],
|
||||
name="test_preprocessor",
|
||||
to_transition=lambda x: x,
|
||||
@@ -995,10 +1000,10 @@ def test_policy_processor_integration():
|
||||
)
|
||||
|
||||
# Create output processor (postprocessor) that moves to CPU
|
||||
output_processor = RobotProcessor(
|
||||
output_processor = DataProcessorPipeline(
|
||||
steps=[
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
|
||||
],
|
||||
name="test_postprocessor",
|
||||
to_transition=lambda x: x,
|
||||
|
||||
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
@@ -84,20 +84,20 @@ def test_make_diffusion_processor_basic():
|
||||
preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
|
||||
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
|
||||
|
||||
|
||||
def test_diffusion_processor_with_images():
|
||||
@@ -257,7 +257,7 @@ def test_diffusion_processor_save_and_load():
|
||||
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
preprocessor = DataProcessorPipeline(
|
||||
factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -266,7 +266,7 @@ def test_diffusion_processor_save_and_load():
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -294,16 +294,27 @@ def test_diffusion_processor_mixed_precision():
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
# Replace DeviceProcessorStep with one that uses float16
|
||||
modified_steps = []
|
||||
for step in factory_preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessor):
|
||||
modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Update normalizer to use the same device as the device processor
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float16, # Match the float16 dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
|
||||
preprocessor = DataProcessorPipeline(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
@@ -379,3 +390,66 @@ def test_diffusion_processor_batch_consistency():
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == expected_batch
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape[0] == expected_batch
|
||||
assert processed[TransitionKey.ACTION].shape[0] == expected_batch
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_diffusion_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, _ = make_diffusion_pre_post_processors(config, stats)
|
||||
|
||||
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
|
||||
modified_steps = []
|
||||
for step in factory_preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
# Device processor converts to bfloat16
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float32, # Deliberately configured as float32
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
|
||||
# Create new processor with modified steps
|
||||
preprocessor = DataProcessorPipeline(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
# Verify initial normalizer configuration
|
||||
normalizer_step = modified_steps[3] # NormalizerProcessorStep
|
||||
assert normalizer_step.dtype == torch.float32
|
||||
|
||||
# Create test data with both state and visual observations
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(7, dtype=torch.float32),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
|
||||
assert (
|
||||
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
|
||||
) # IDENTITY normalization still gets dtype conversion
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
|
||||
|
||||
# Verify normalizer automatically adapted its internal state
|
||||
assert normalizer_step.dtype == torch.bfloat16
|
||||
# Check state stats (has normalization)
|
||||
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
|
||||
|
||||
@@ -20,13 +20,15 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.normalize_processor import (
|
||||
NormalizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
IdentityProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
hotswap_stats,
|
||||
)
|
||||
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
|
||||
from lerobot.processor.converters import to_tensor
|
||||
|
||||
|
||||
def create_transition(
|
||||
@@ -123,7 +125,7 @@ def _create_observation_norm_map():
|
||||
}
|
||||
|
||||
|
||||
# Fixtures for observation normalisation tests using NormalizerProcessor
|
||||
# Fixtures for observation normalisation tests using NormalizerProcessorStep
|
||||
@pytest.fixture
|
||||
def observation_stats():
|
||||
return {
|
||||
@@ -140,10 +142,10 @@ def observation_stats():
|
||||
|
||||
@pytest.fixture
|
||||
def observation_normalizer(observation_stats):
|
||||
"""Return a NormalizerProcessor that only has observation stats (no action)."""
|
||||
"""Return a NormalizerProcessorStep that only has observation stats (no action)."""
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
|
||||
return NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats)
|
||||
|
||||
|
||||
def test_mean_std_normalization(observation_normalizer):
|
||||
@@ -180,7 +182,7 @@ def test_min_max_normalization(observation_normalizer):
|
||||
def test_selective_normalization(observation_stats):
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
normalizer = NormalizerProcessor(
|
||||
normalizer = NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=observation_stats,
|
||||
@@ -206,7 +208,7 @@ def test_selective_normalization(observation_stats):
|
||||
def test_device_compatibility(observation_stats):
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats)
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
||||
}
|
||||
@@ -235,7 +237,7 @@ def test_from_lerobot_dataset():
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
normalizer = NormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
|
||||
# Both observation and action statistics should be present in tensor stats
|
||||
assert "observation.image" in normalizer._tensor_stats
|
||||
@@ -250,7 +252,7 @@ def test_state_dict_save_load(observation_normalizer):
|
||||
# Create new normalizer and load state
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
|
||||
new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
|
||||
new_normalizer.load_state_dict(state_dict)
|
||||
|
||||
# Test that it works the same
|
||||
@@ -301,7 +303,7 @@ def _create_action_norm_map_min_max():
|
||||
def test_mean_std_unnormalization(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessor(
|
||||
unnormalizer = UnnormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
)
|
||||
|
||||
@@ -319,7 +321,7 @@ def test_mean_std_unnormalization(action_stats_mean_std):
|
||||
def test_min_max_unnormalization(action_stats_min_max):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_min_max()
|
||||
unnormalizer = UnnormalizerProcessor(
|
||||
unnormalizer = UnnormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_min_max}
|
||||
)
|
||||
|
||||
@@ -345,7 +347,7 @@ def test_min_max_unnormalization(action_stats_min_max):
|
||||
def test_numpy_action_input(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessor(
|
||||
unnormalizer = UnnormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
)
|
||||
|
||||
@@ -363,7 +365,7 @@ def test_numpy_action_input(action_stats_mean_std):
|
||||
def test_none_action(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessor(
|
||||
unnormalizer = UnnormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
)
|
||||
|
||||
@@ -379,11 +381,11 @@ def test_action_from_lerobot_dataset():
|
||||
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (1,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
unnormalizer = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
assert "mean" in unnormalizer._tensor_stats["action"]
|
||||
|
||||
|
||||
# Fixtures for NormalizerProcessor tests
|
||||
# Fixtures for NormalizerProcessorStep tests
|
||||
@pytest.fixture
|
||||
def full_stats():
|
||||
return {
|
||||
@@ -422,7 +424,7 @@ def _create_full_norm_map():
|
||||
def normalizer_processor(full_stats):
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats)
|
||||
return NormalizerProcessorStep(features=features, norm_map=norm_map, stats=full_stats)
|
||||
|
||||
|
||||
def test_combined_normalization(normalizer_processor):
|
||||
@@ -466,7 +468,7 @@ def test_processor_from_lerobot_dataset(full_stats):
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
|
||||
processor = NormalizerProcessor.from_lerobot_dataset(
|
||||
processor = NormalizerProcessorStep.from_lerobot_dataset(
|
||||
mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"}
|
||||
)
|
||||
|
||||
@@ -478,7 +480,7 @@ def test_processor_from_lerobot_dataset(full_stats):
|
||||
def test_get_config(full_stats):
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
processor = NormalizerProcessor(
|
||||
processor = NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=full_stats,
|
||||
@@ -506,7 +508,9 @@ def test_get_config(full_stats):
|
||||
|
||||
def test_integration_with_robot_processor(normalizer_processor):
|
||||
"""Test integration with RobotProcessor pipeline"""
|
||||
robot_processor = RobotProcessor([normalizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[normalizer_processor], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
@@ -535,7 +539,7 @@ def test_empty_observation():
|
||||
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
transition = create_transition()
|
||||
result = normalizer(transition)
|
||||
@@ -546,7 +550,7 @@ def test_empty_observation():
|
||||
def test_empty_stats():
|
||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
|
||||
observation = {"observation.image": torch.tensor([0.5])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
@@ -562,7 +566,7 @@ def test_partial_stats():
|
||||
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
|
||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
observation = {"observation.image": torch.tensor([0.7])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
@@ -577,7 +581,7 @@ def test_missing_action_stats_no_error():
|
||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
# The tensor stats should not contain the 'action' key
|
||||
assert "action" not in processor._tensor_stats
|
||||
|
||||
@@ -586,7 +590,7 @@ def test_serialization_roundtrip(full_stats):
|
||||
"""Test that features and norm_map can be serialized and deserialized correctly."""
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
original_processor = NormalizerProcessor(
|
||||
original_processor = NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=full_stats,
|
||||
@@ -598,7 +602,7 @@ def test_serialization_roundtrip(full_stats):
|
||||
config = original_processor.get_config()
|
||||
|
||||
# Create a new processor from the config (deserialization)
|
||||
new_processor = NormalizerProcessor(
|
||||
new_processor = NormalizerProcessorStep(
|
||||
features=config["features"],
|
||||
norm_map=config["norm_map"],
|
||||
stats=full_stats,
|
||||
@@ -666,7 +670,7 @@ def test_identity_normalization_observations():
|
||||
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
@@ -691,7 +695,7 @@ def test_identity_normalization_actions():
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
|
||||
stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(action=action)
|
||||
@@ -717,7 +721,7 @@ def test_identity_unnormalization_observations():
|
||||
"observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
@@ -744,7 +748,7 @@ def test_identity_unnormalization_actions():
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY}
|
||||
stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}}
|
||||
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
action = torch.tensor([0.5, -0.8]) # Normalized values
|
||||
transition = create_transition(action=action)
|
||||
@@ -767,8 +771,8 @@ def test_identity_with_missing_stats():
|
||||
}
|
||||
stats = {} # No stats provided
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
@@ -808,7 +812,7 @@ def test_identity_mixed_with_other_modes():
|
||||
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
@@ -850,7 +854,7 @@ def test_identity_defaults_when_not_in_norm_map():
|
||||
"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
@@ -884,8 +888,8 @@ def test_identity_roundtrip():
|
||||
"action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
original_action = torch.tensor([0.5, -0.2])
|
||||
@@ -917,7 +921,7 @@ def test_identity_config_serialization():
|
||||
"action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Get config
|
||||
config = normalizer.get_config()
|
||||
@@ -927,7 +931,7 @@ def test_identity_config_serialization():
|
||||
assert config["norm_map"]["ACTION"] == "MEAN_STD"
|
||||
|
||||
# Create new processor from config (simulating load)
|
||||
new_normalizer = NormalizerProcessor(
|
||||
new_normalizer = NormalizerProcessorStep(
|
||||
features=config["features"],
|
||||
norm_map=config["norm_map"],
|
||||
stats=stats,
|
||||
@@ -965,7 +969,7 @@ def test_identity_config_serialization():
|
||||
# norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
|
||||
|
||||
# normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
# normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# # Manually inject an invalid mode to test error handling
|
||||
# normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
|
||||
@@ -1002,12 +1006,12 @@ def test_hotswap_stats_basic_functionality():
|
||||
}
|
||||
|
||||
# Create processors
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
identity = IdentityProcessor()
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
identity = IdentityProcessorStep()
|
||||
|
||||
# Create robot processor
|
||||
robot_processor = RobotProcessor(steps=[normalizer, unnormalizer, identity])
|
||||
robot_processor = DataProcessorPipeline(steps=[normalizer, unnormalizer, identity])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
@@ -1043,8 +1047,8 @@ def test_hotswap_stats_deep_copy():
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
original_processor = RobotProcessor(steps=[normalizer])
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
original_processor = DataProcessorPipeline(steps=[normalizer])
|
||||
|
||||
# Store reference to original stats
|
||||
original_stats_reference = original_processor.steps[0].stats
|
||||
@@ -1068,7 +1072,7 @@ def test_hotswap_stats_deep_copy():
|
||||
|
||||
|
||||
def test_hotswap_stats_only_affects_normalizer_steps():
|
||||
"""Test that hotswap_stats only modifies NormalizerProcessor and UnnormalizerProcessor steps."""
|
||||
"""Test that hotswap_stats only modifies NormalizerProcessorStep and UnnormalizerProcessorStep steps."""
|
||||
stats = {
|
||||
"observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])},
|
||||
}
|
||||
@@ -1083,11 +1087,11 @@ def test_hotswap_stats_only_affects_normalizer_steps():
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
# Create mixed steps
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
identity = IdentityProcessor()
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
identity = IdentityProcessorStep()
|
||||
|
||||
robot_processor = RobotProcessor(steps=[normalizer, identity, unnormalizer])
|
||||
robot_processor = DataProcessorPipeline(steps=[normalizer, identity, unnormalizer])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
@@ -1113,8 +1117,8 @@ def test_hotswap_stats_empty_stats():
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
robot_processor = RobotProcessor(steps=[normalizer])
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
robot_processor = DataProcessorPipeline(steps=[normalizer])
|
||||
|
||||
# Hotswap with empty stats
|
||||
new_processor = hotswap_stats(robot_processor, empty_stats)
|
||||
@@ -1131,7 +1135,7 @@ def test_hotswap_stats_no_normalizer_steps():
|
||||
}
|
||||
|
||||
# Create processor with only identity steps
|
||||
robot_processor = RobotProcessor(steps=[IdentityProcessor(), IdentityProcessor()])
|
||||
robot_processor = DataProcessorPipeline(steps=[IdentityProcessorStep(), IdentityProcessorStep()])
|
||||
|
||||
# Hotswap stats - should work without error
|
||||
new_processor = hotswap_stats(robot_processor, stats)
|
||||
@@ -1163,14 +1167,14 @@ def test_hotswap_stats_preserves_other_attributes():
|
||||
normalize_observation_keys = {"observation.image"}
|
||||
eps = 1e-6
|
||||
|
||||
normalizer = NormalizerProcessor(
|
||||
normalizer = NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=initial_stats,
|
||||
normalize_observation_keys=normalize_observation_keys,
|
||||
eps=eps,
|
||||
)
|
||||
robot_processor = RobotProcessor(steps=[normalizer])
|
||||
robot_processor = DataProcessorPipeline(steps=[normalizer])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
@@ -1208,12 +1212,12 @@ def test_hotswap_stats_multiple_normalizer_types():
|
||||
}
|
||||
|
||||
# Create multiple normalizers and unnormalizers
|
||||
normalizer1 = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
normalizer2 = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
unnormalizer1 = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
unnormalizer2 = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
unnormalizer1 = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
unnormalizer2 = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
|
||||
robot_processor = RobotProcessor(steps=[normalizer1, unnormalizer1, normalizer2, unnormalizer2])
|
||||
robot_processor = DataProcessorPipeline(steps=[normalizer1, unnormalizer1, normalizer2, unnormalizer2])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
@@ -1260,8 +1264,8 @@ def test_hotswap_stats_with_different_data_types():
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
robot_processor = RobotProcessor(steps=[normalizer])
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
robot_processor = DataProcessorPipeline(steps=[normalizer])
|
||||
|
||||
# Hotswap stats
|
||||
new_processor = hotswap_stats(robot_processor, new_stats)
|
||||
@@ -1316,8 +1320,10 @@ def test_hotswap_stats_functional_test():
|
||||
}
|
||||
|
||||
# Create original processor
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
original_processor = RobotProcessor(steps=[normalizer], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats)
|
||||
original_processor = DataProcessorPipeline(
|
||||
steps=[normalizer], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
# Process with original stats
|
||||
original_result = original_processor(transition)
|
||||
@@ -1360,7 +1366,7 @@ def test_zero_std_uses_eps():
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": np.array([0.5]), "std": np.array([0.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
|
||||
|
||||
observation = {"observation.state": torch.tensor([0.5])} # equals mean
|
||||
out = normalizer(create_transition(observation=observation))
|
||||
@@ -1372,7 +1378,7 @@ def test_min_equals_max_maps_to_minus_one():
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX}
|
||||
stats = {"observation.state": {"min": np.array([2.0]), "max": np.array([2.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6)
|
||||
|
||||
observation = {"observation.state": torch.tensor([2.0])}
|
||||
out = normalizer(create_transition(observation=observation))
|
||||
@@ -1387,7 +1393,7 @@ def test_action_normalized_despite_normalize_observation_keys():
|
||||
}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
normalizer = NormalizerProcessor(
|
||||
normalizer = NormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"}
|
||||
)
|
||||
|
||||
@@ -1405,12 +1411,12 @@ def test_unnormalize_observations_mean_std_and_min_max():
|
||||
"observation.mm": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
# Build two processors: one mean/std and one min/max
|
||||
unnorm_ms = UnnormalizerProcessor(
|
||||
unnorm_ms = UnnormalizerProcessorStep(
|
||||
features={"observation.ms": features["observation.ms"]},
|
||||
norm_map={FeatureType.STATE: NormalizationMode.MEAN_STD},
|
||||
stats={"observation.ms": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}},
|
||||
)
|
||||
unnorm_mm = UnnormalizerProcessor(
|
||||
unnorm_mm = UnnormalizerProcessorStep(
|
||||
features={"observation.mm": features["observation.mm"]},
|
||||
norm_map={FeatureType.STATE: NormalizationMode.MIN_MAX},
|
||||
stats={"observation.mm": {"min": np.array([0.0, -2.0]), "max": np.array([2.0, 2.0])}},
|
||||
@@ -1432,7 +1438,7 @@ def test_unknown_observation_keys_ignored():
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
obs = {"observation.state": torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])}
|
||||
tr = create_transition(observation=obs)
|
||||
@@ -1446,7 +1452,7 @@ def test_batched_action_normalization():
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (2,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
actions = torch.tensor([[1.0, -1.0], [3.0, 3.0]]) # first equals mean → zeros; second → [1, 1]
|
||||
out = normalizer(create_transition(action=actions))[TransitionKey.ACTION]
|
||||
@@ -1458,7 +1464,7 @@ def test_complementary_data_preservation():
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
comp = {"existing": 123}
|
||||
tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp)
|
||||
@@ -1477,8 +1483,8 @@ def test_roundtrip_normalize_unnormalize_non_identity():
|
||||
"observation.state": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])},
|
||||
"action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])},
|
||||
}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Add a time dimension in action for broadcasting check (B,T,D)
|
||||
obs = {"observation.state": torch.tensor([[3.0, 3.0], [1.0, -1.0]])}
|
||||
@@ -1491,3 +1497,205 @@ def test_roundtrip_normalize_unnormalize_non_identity():
|
||||
out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5
|
||||
)
|
||||
assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5)
|
||||
|
||||
|
||||
def test_dtype_adaptation_bfloat16_input_float32_normalizer():
|
||||
"""Test automatic dtype adaptation: NormalizerProcessor(float32) adapts to bfloat16 input → bfloat16 output"""
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (5,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {
|
||||
"observation.state": {
|
||||
"mean": np.array([0.0, 0.0, 0.0, 0.0, 0.0]),
|
||||
"std": np.array([1.0, 1.0, 1.0, 1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
|
||||
# Create normalizer configured with float32 dtype
|
||||
normalizer = NormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats=stats, dtype=torch.float32
|
||||
)
|
||||
|
||||
# Verify initial configuration
|
||||
assert normalizer.dtype == torch.float32
|
||||
for stat_tensor in normalizer._tensor_stats["observation.state"].values():
|
||||
assert stat_tensor.dtype == torch.float32
|
||||
|
||||
# Create bfloat16 input tensor
|
||||
observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Process the transition
|
||||
result = normalizer(transition)
|
||||
|
||||
# Verify that:
|
||||
# 1. Stats were automatically adapted to bfloat16
|
||||
assert normalizer.dtype == torch.bfloat16
|
||||
for stat_tensor in normalizer._tensor_stats["observation.state"].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
|
||||
# 2. Output is in bfloat16
|
||||
output_tensor = result[TransitionKey.OBSERVATION]["observation.state"]
|
||||
assert output_tensor.dtype == torch.bfloat16
|
||||
|
||||
# 3. Normalization was applied correctly (mean should be close to original - mean) / std
|
||||
expected = (
|
||||
torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)
|
||||
- torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], dtype=torch.bfloat16)
|
||||
) / torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.bfloat16)
|
||||
assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
|
||||
"""Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output"""
|
||||
from lerobot.processor import DeviceProcessorStep
|
||||
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (3,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}}
|
||||
|
||||
# Create pipeline: DeviceProcessor(bfloat16) → NormalizerProcessor(float32)
|
||||
device_processor = DeviceProcessorStep(device="cuda", float_dtype="bfloat16")
|
||||
normalizer = NormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats=stats, dtype=torch.float32
|
||||
)
|
||||
|
||||
# Verify initial normalizer configuration
|
||||
assert normalizer.dtype == torch.float32
|
||||
|
||||
# Create CPU input
|
||||
observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Step 1: DeviceProcessor converts to bfloat16 + moves to CUDA
|
||||
processed_1 = device_processor(transition)
|
||||
intermediate_tensor = processed_1[TransitionKey.OBSERVATION]["observation.state"]
|
||||
assert intermediate_tensor.dtype == torch.bfloat16
|
||||
assert intermediate_tensor.device.type == "cuda"
|
||||
|
||||
# Step 2: NormalizerProcessor receives bfloat16 input and adapts
|
||||
final_result = normalizer(processed_1)
|
||||
final_tensor = final_result[TransitionKey.OBSERVATION]["observation.state"]
|
||||
|
||||
# Verify final output is bfloat16 (automatic adaptation worked)
|
||||
assert final_tensor.dtype == torch.bfloat16
|
||||
assert final_tensor.device.type == "cuda"
|
||||
|
||||
# Verify normalizer adapted its internal state
|
||||
assert normalizer.dtype == torch.bfloat16
|
||||
for stat_tensor in normalizer._tensor_stats["observation.state"].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
assert stat_tensor.device.type == "cuda"
|
||||
|
||||
|
||||
def test_stats_reconstruction_after_load_state_dict():
|
||||
"""
|
||||
Test that stats dict is properly reconstructed from _tensor_stats after loading.
|
||||
|
||||
This test ensures the bug where stats became empty after loading is fixed.
|
||||
The bug occurred when:
|
||||
1. Only _tensor_stats were saved via state_dict()
|
||||
2. stats field became empty {} after loading
|
||||
3. Calling to() method or hotswap_stats would fail because they depend on self.stats
|
||||
"""
|
||||
|
||||
# Create normalizer with stats
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"observation.state": {
|
||||
"min": np.array([0.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
"action": {
|
||||
"mean": np.array([0.0, 0.0]),
|
||||
"std": np.array([1.0, 2.0]),
|
||||
},
|
||||
}
|
||||
|
||||
original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Save state dict (simulating save/load)
|
||||
state_dict = original_normalizer.state_dict()
|
||||
|
||||
# Create new normalizer with empty stats (simulating load)
|
||||
new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
|
||||
|
||||
# Before fix: this would cause stats to remain empty
|
||||
new_normalizer.load_state_dict(state_dict)
|
||||
|
||||
# Verify that stats dict is properly reconstructed from _tensor_stats
|
||||
assert new_normalizer.stats is not None
|
||||
assert new_normalizer.stats != {}
|
||||
|
||||
# Check that all expected keys are present
|
||||
assert "observation.image" in new_normalizer.stats
|
||||
assert "observation.state" in new_normalizer.stats
|
||||
assert "action" in new_normalizer.stats
|
||||
|
||||
# Check that values are correct (converted back from tensors)
|
||||
np.testing.assert_allclose(new_normalizer.stats["observation.image"]["mean"], [0.5, 0.5, 0.5])
|
||||
np.testing.assert_allclose(new_normalizer.stats["observation.image"]["std"], [0.2, 0.2, 0.2])
|
||||
np.testing.assert_allclose(new_normalizer.stats["observation.state"]["min"], [0.0, -1.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats["observation.state"]["max"], [1.0, 1.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0])
|
||||
np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0])
|
||||
|
||||
# Test that methods that depend on self.stats work correctly after loading
|
||||
# This would fail before the bug fix because self.stats was empty
|
||||
|
||||
# Test 1: to() method should work without crashing
|
||||
try:
|
||||
new_normalizer.to(device="cpu", dtype=torch.float32)
|
||||
# If we reach here, the bug is fixed
|
||||
except (KeyError, AttributeError) as e:
|
||||
pytest.fail(f"to() method failed after loading state_dict: {e}")
|
||||
|
||||
# Test 2: hotswap_stats should work
|
||||
new_stats = {
|
||||
"observation.image": {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]},
|
||||
"observation.state": {"min": [-1.0, -2.0], "max": [2.0, 2.0]},
|
||||
"action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]},
|
||||
}
|
||||
|
||||
pipeline = DataProcessorPipeline([new_normalizer])
|
||||
try:
|
||||
new_pipeline = hotswap_stats(pipeline, new_stats)
|
||||
# If we reach here, hotswap_stats worked correctly
|
||||
assert new_pipeline.steps[0].stats == new_stats
|
||||
except (KeyError, AttributeError) as e:
|
||||
pytest.fail(f"hotswap_stats failed after loading state_dict: {e}")
|
||||
|
||||
# Test 3: The normalizer should work functionally the same as the original
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
original_result = original_normalizer(transition)
|
||||
new_result = new_normalizer(transition)
|
||||
|
||||
# Results should be identical (within floating point precision)
|
||||
torch.testing.assert_close(
|
||||
original_result[TransitionKey.OBSERVATION]["observation.image"],
|
||||
new_result[TransitionKey.OBSERVATION]["observation.image"],
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
original_result[TransitionKey.OBSERVATION]["observation.state"],
|
||||
new_result[TransitionKey.OBSERVATION]["observation.state"],
|
||||
)
|
||||
torch.testing.assert_close(original_result[TransitionKey.ACTION], new_result[TransitionKey.ACTION])
|
||||
|
||||
@@ -18,10 +18,9 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor import VanillaObservationProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from lerobot.processor import TransitionKey, VanillaObservationProcessorStep
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -42,7 +41,7 @@ def create_transition(
|
||||
|
||||
def test_process_single_image():
|
||||
"""Test processing a single image."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Create a mock image (H, W, C) format, uint8
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
@@ -68,7 +67,7 @@ def test_process_single_image():
|
||||
|
||||
def test_process_image_dict():
|
||||
"""Test processing multiple images in a dictionary."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Create mock images
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
@@ -91,7 +90,7 @@ def test_process_image_dict():
|
||||
|
||||
def test_process_batched_image():
|
||||
"""Test processing already batched images."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Create a batched image (B, H, W, C)
|
||||
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
|
||||
@@ -108,7 +107,7 @@ def test_process_batched_image():
|
||||
|
||||
def test_invalid_image_format():
|
||||
"""Test error handling for invalid image formats."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Test wrong channel order (channels first)
|
||||
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
|
||||
@@ -121,7 +120,7 @@ def test_invalid_image_format():
|
||||
|
||||
def test_invalid_image_dtype():
|
||||
"""Test error handling for invalid image dtype."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Test wrong dtype
|
||||
image = np.random.rand(64, 64, 3).astype(np.float32)
|
||||
@@ -134,7 +133,7 @@ def test_invalid_image_dtype():
|
||||
|
||||
def test_no_pixels_in_observation():
|
||||
"""Test processor when no pixels are in observation."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -149,7 +148,7 @@ def test_no_pixels_in_observation():
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
@@ -159,7 +158,7 @@ def test_none_observation():
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test serialization methods."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
@@ -178,7 +177,7 @@ def test_serialization_methods():
|
||||
|
||||
def test_process_environment_state():
|
||||
"""Test processing environment_state."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
observation = {"environment_state": env_state}
|
||||
@@ -199,7 +198,7 @@ def test_process_environment_state():
|
||||
|
||||
def test_process_agent_pos():
|
||||
"""Test processing agent_pos."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
@@ -220,7 +219,7 @@ def test_process_agent_pos():
|
||||
|
||||
def test_process_batched_states():
|
||||
"""Test processing already batched states."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
||||
@@ -238,7 +237,7 @@ def test_process_batched_states():
|
||||
|
||||
def test_process_both_states():
|
||||
"""Test processing both environment_state and agent_pos."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
env_state = np.array([1.0, 2.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
||||
@@ -263,7 +262,7 @@ def test_process_both_states():
|
||||
|
||||
def test_no_states_in_observation():
|
||||
"""Test processor when no states are in observation."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -277,7 +276,7 @@ def test_no_states_in_observation():
|
||||
|
||||
def test_complete_observation_processing():
|
||||
"""Test processing a complete observation with both images and states."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Create mock data
|
||||
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
@@ -314,7 +313,7 @@ def test_complete_observation_processing():
|
||||
|
||||
def test_image_only_processing():
|
||||
"""Test processing observation with only images."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
@@ -329,7 +328,7 @@ def test_image_only_processing():
|
||||
|
||||
def test_state_only_processing():
|
||||
"""Test processing observation with only states."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
@@ -344,7 +343,7 @@ def test_state_only_processing():
|
||||
|
||||
def test_empty_observation():
|
||||
"""Test processing empty observation."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
observation = {}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -360,7 +359,7 @@ def test_equivalent_to_original_function():
|
||||
# Import the original function for comparison
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Create test data similar to what the original function expects
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
@@ -387,7 +386,7 @@ def test_equivalent_with_image_dict():
|
||||
"""Test equivalence with dictionary of images."""
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Create test data with multiple cameras
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
@@ -411,76 +410,132 @@ def test_equivalent_with_image_dict():
|
||||
|
||||
|
||||
def test_image_processor_features_pixels_to_image(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
|
||||
assert "pixels" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert (
|
||||
OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["pixels"]
|
||||
)
|
||||
assert "pixels" not in out[PipelineFeatureType.OBSERVATION]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_image_processor_features_observation_pixels_to_image(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
|
||||
assert "observation.pixels" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert (
|
||||
OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.pixels"]
|
||||
)
|
||||
assert "observation.pixels" not in out[PipelineFeatureType.OBSERVATION]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
|
||||
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
|
||||
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
|
||||
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert (
|
||||
f"{OBS_IMAGES}.front" in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"]
|
||||
== features[PipelineFeatureType.OBSERVATION]["pixels.front"]
|
||||
)
|
||||
assert (
|
||||
f"{OBS_IMAGES}.wrist" in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.wrist"]
|
||||
== features[PipelineFeatureType.OBSERVATION]["pixels.wrist"]
|
||||
)
|
||||
assert (
|
||||
f"{OBS_IMAGES}.rear" in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.rear"]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.pixels.rear"]
|
||||
)
|
||||
assert (
|
||||
"pixels.front" not in out[PipelineFeatureType.OBSERVATION]
|
||||
and "pixels.wrist" not in out[PipelineFeatureType.OBSERVATION]
|
||||
and "observation.pixels.rear" not in out[PipelineFeatureType.OBSERVATION]
|
||||
)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_state_processor_features_environment_and_agent_pos(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
|
||||
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
|
||||
assert "environment_state" not in out and "agent_pos" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert (
|
||||
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["environment_state"]
|
||||
)
|
||||
assert (
|
||||
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["agent_pos"]
|
||||
)
|
||||
assert (
|
||||
"environment_state" not in out[PipelineFeatureType.OBSERVATION]
|
||||
and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
|
||||
)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_state_processor_features_prefixed_inputs(policy_feature_factory):
|
||||
proc = VanillaObservationProcessor()
|
||||
proc = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
},
|
||||
}
|
||||
out = proc.transform_features(features.copy())
|
||||
|
||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
|
||||
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
|
||||
assert "environment_state" not in out and "agent_pos" not in out
|
||||
assert (
|
||||
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.environment_state"]
|
||||
)
|
||||
assert (
|
||||
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.agent_pos"]
|
||||
)
|
||||
assert (
|
||||
"environment_state" not in out[PipelineFeatureType.OBSERVATION]
|
||||
and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
|
||||
)
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
@@ -25,13 +25,31 @@ from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
ProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
"""Mock tokenizer processor step for testing."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Accept any arguments to mimic the real TokenizerProcessorStep interface
|
||||
pass
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Pass through transition unchanged
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
# Pass through features unchanged
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
@@ -83,7 +101,7 @@ def test_make_pi0_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -92,22 +110,22 @@ def test_make_pi0_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 6
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], Pi0NewLineProcessor)
|
||||
# Step 4 would be TokenizerProcessor but it's mocked
|
||||
assert isinstance(preprocessor.steps[5], DeviceProcessor)
|
||||
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor)
|
||||
# Step 3 would be TokenizerProcessorStep but it's mocked
|
||||
assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
|
||||
assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
|
||||
|
||||
|
||||
def test_pi0_newline_processor_single_task():
|
||||
@@ -165,7 +183,7 @@ def test_pi0_processor_cuda():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -187,7 +205,7 @@ def test_pi0_processor_cuda():
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -220,7 +238,7 @@ def test_pi0_processor_accelerate_scenario():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -242,7 +260,7 @@ def test_pi0_processor_accelerate_scenario():
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -276,7 +294,7 @@ def test_pi0_processor_multi_gpu():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -298,7 +316,7 @@ def test_pi0_processor_multi_gpu():
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -329,7 +347,7 @@ def test_pi0_processor_without_stats():
|
||||
config = create_default_config()
|
||||
|
||||
# Mock the tokenizer processor
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
dataset_stats=None,
|
||||
@@ -359,3 +377,71 @@ def test_pi0_newline_processor_state_dict():
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
assert config == {}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_pi0_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
config.device = "cuda"
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, _ = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
# Device processor converts to bfloat16
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float32, # Deliberately configured as float32
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5)
|
||||
normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep
|
||||
assert normalizer_step.dtype == torch.float32
|
||||
|
||||
# Create test data with both state and visual observations
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10, dtype=torch.float32), # PI0 expects size 10
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
|
||||
transition = create_transition(
|
||||
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
|
||||
assert (
|
||||
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
|
||||
) # IDENTITY normalization still gets dtype conversion
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
|
||||
|
||||
# Verify normalizer automatically adapted its internal state
|
||||
assert normalizer_step.dtype == torch.bfloat16
|
||||
# Check state stats (has normalization)
|
||||
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
|
||||
|
||||
+285
-303
File diff suppressed because it is too large
Load Diff
@@ -19,8 +19,13 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
@@ -46,7 +51,7 @@ def test_basic_renaming():
|
||||
"old_key1": "new_key1",
|
||||
"old_key2": "new_key2",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"old_key1": torch.tensor([1.0, 2.0]),
|
||||
@@ -74,7 +79,7 @@ def test_basic_renaming():
|
||||
|
||||
def test_empty_rename_map():
|
||||
"""Test processor with empty rename map (should pass through unchanged)."""
|
||||
processor = RenameProcessor(rename_map={})
|
||||
processor = RenameObservationsProcessorStep(rename_map={})
|
||||
|
||||
observation = {
|
||||
"key1": torch.tensor([1.0]),
|
||||
@@ -93,7 +98,7 @@ def test_empty_rename_map():
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = RenameProcessor(rename_map={"old": "new"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"old": "new"})
|
||||
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
@@ -108,7 +113,7 @@ def test_overlapping_rename():
|
||||
"a": "b",
|
||||
"b": "c", # This creates a potential conflict
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"a": 1,
|
||||
@@ -133,7 +138,7 @@ def test_partial_rename():
|
||||
"observation.state": "observation.proprio_state",
|
||||
"pixels": "observation.image",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.randn(10),
|
||||
@@ -163,15 +168,15 @@ def test_get_config():
|
||||
"old1": "new1",
|
||||
"old2": "new2",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
config = processor.get_config()
|
||||
assert config == {"rename_map": rename_map}
|
||||
|
||||
|
||||
def test_state_dict():
|
||||
"""Test state dict (should be empty for RenameProcessor)."""
|
||||
processor = RenameProcessor(rename_map={"old": "new"})
|
||||
"""Test state dict (should be empty for RenameProcessorStep)."""
|
||||
processor = RenameObservationsProcessorStep(rename_map={"old": "new"})
|
||||
|
||||
state = processor.state_dict()
|
||||
assert state == {}
|
||||
@@ -186,9 +191,9 @@ def test_integration_with_robot_processor():
|
||||
"agent_pos": "observation.state",
|
||||
"pixels": "observation.image",
|
||||
}
|
||||
rename_processor = RenameProcessor(rename_map=rename_map)
|
||||
rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
pipeline = RobotProcessor([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = DataProcessorPipeline([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
observation = {
|
||||
"agent_pos": np.array([1.0, 2.0, 3.0]),
|
||||
@@ -220,32 +225,34 @@ def test_save_and_load_pretrained():
|
||||
"old_state": "observation.state",
|
||||
"old_image": "observation.image",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
pipeline = RobotProcessor([processor], name="TestRenameProcessor")
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check files were created
|
||||
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
|
||||
config_path = (
|
||||
Path(tmp_dir) / "testrenameprocessorstep.json"
|
||||
) # Based on name="TestRenameProcessorStep"
|
||||
assert config_path.exists()
|
||||
|
||||
# No state files should be created for RenameProcessor
|
||||
# No state files should be created for RenameProcessorStep
|
||||
state_files = list(Path(tmp_dir).glob("*.safetensors"))
|
||||
assert len(state_files) == 0
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
assert loaded_pipeline.name == "TestRenameProcessor"
|
||||
assert loaded_pipeline.name == "TestRenameProcessorStep"
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
# Check that loaded processor works correctly
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, RenameProcessor)
|
||||
assert isinstance(loaded_processor, RenameObservationsProcessorStep)
|
||||
assert loaded_processor.rename_map == rename_map
|
||||
|
||||
# Test functionality after loading
|
||||
@@ -262,24 +269,24 @@ def test_save_and_load_pretrained():
|
||||
|
||||
|
||||
def test_registry_functionality():
|
||||
"""Test that RenameProcessor is properly registered."""
|
||||
"""Test that RenameProcessorStep is properly registered."""
|
||||
# Check that it's registered
|
||||
assert "rename_processor" in ProcessorStepRegistry.list()
|
||||
assert "rename_observations_processor" in ProcessorStepRegistry.list()
|
||||
|
||||
# Get from registry
|
||||
retrieved_class = ProcessorStepRegistry.get("rename_processor")
|
||||
assert retrieved_class is RenameProcessor
|
||||
retrieved_class = ProcessorStepRegistry.get("rename_observations_processor")
|
||||
assert retrieved_class is RenameObservationsProcessorStep
|
||||
|
||||
# Create instance from registry
|
||||
instance = retrieved_class(rename_map={"old": "new"})
|
||||
assert isinstance(instance, RenameProcessor)
|
||||
assert isinstance(instance, RenameObservationsProcessorStep)
|
||||
assert instance.rename_map == {"old": "new"}
|
||||
|
||||
|
||||
def test_registry_based_save_load():
|
||||
"""Test save/load using registry name instead of module path."""
|
||||
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
|
||||
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
processor = RenameObservationsProcessorStep(rename_map={"key1": "renamed_key1"})
|
||||
pipeline = DataProcessorPipeline([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save and load
|
||||
@@ -288,24 +295,24 @@ def test_registry_based_save_load():
|
||||
# Verify config uses registry name
|
||||
import json
|
||||
|
||||
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
|
||||
with open(Path(tmp_dir) / "dataprocessorpipeline.json") as f: # Default name is "RobotProcessor"
|
||||
config = json.load(f)
|
||||
|
||||
assert "registry_name" in config["steps"][0]
|
||||
assert config["steps"][0]["registry_name"] == "rename_processor"
|
||||
assert config["steps"][0]["registry_name"] == "rename_observations_processor"
|
||||
assert "class" not in config["steps"][0] # Should use registry, not module path
|
||||
|
||||
# Load should work
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(tmp_dir)
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, RenameProcessor)
|
||||
assert isinstance(loaded_processor, RenameObservationsProcessorStep)
|
||||
assert loaded_processor.rename_map == {"key1": "renamed_key1"}
|
||||
|
||||
|
||||
def test_chained_rename_processors():
|
||||
"""Test multiple RenameProcessors in a pipeline."""
|
||||
"""Test multiple RenameProcessorSteps in a pipeline."""
|
||||
# First processor: rename raw keys to intermediate format
|
||||
processor1 = RenameProcessor(
|
||||
processor1 = RenameObservationsProcessorStep(
|
||||
rename_map={
|
||||
"pos": "agent_position",
|
||||
"img": "camera_image",
|
||||
@@ -313,14 +320,16 @@ def test_chained_rename_processors():
|
||||
)
|
||||
|
||||
# Second processor: rename to final format
|
||||
processor2 = RenameProcessor(
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={
|
||||
"agent_position": "observation.state",
|
||||
"camera_image": "observation.image",
|
||||
}
|
||||
)
|
||||
|
||||
pipeline = RobotProcessor([processor1, processor2], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = DataProcessorPipeline(
|
||||
[processor1, processor2], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
observation = {
|
||||
"pos": np.array([1.0, 2.0]),
|
||||
@@ -356,7 +365,7 @@ def test_nested_observation_rename():
|
||||
"observation.images.right": "observation.camera.right_view",
|
||||
"observation.proprio": "observation.proprioception",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.images.left": torch.randn(3, 64, 64),
|
||||
@@ -386,7 +395,7 @@ def test_nested_observation_rename():
|
||||
def test_value_types_preserved():
|
||||
"""Test that various value types are preserved during renaming."""
|
||||
rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
tensor_value = torch.randn(3, 3)
|
||||
array_value = np.random.rand(2, 2)
|
||||
@@ -414,59 +423,75 @@ def test_value_types_preserved():
|
||||
|
||||
|
||||
def test_features_basic_renaming(policy_feature_factory):
|
||||
processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"a": "x", "b": "y"})
|
||||
features = {
|
||||
"a": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
|
||||
"c": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"a": policy_feature_factory(FeatureType.VISUAL, (2,)),
|
||||
"b": policy_feature_factory(FeatureType.VISUAL, (3,)),
|
||||
"c": policy_feature_factory(FeatureType.VISUAL, (1,)),
|
||||
},
|
||||
}
|
||||
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
# Values preserved and typed
|
||||
assert out["x"] == features["a"]
|
||||
assert out["y"] == features["b"]
|
||||
assert out["c"] == features["c"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["x"] == features[PipelineFeatureType.OBSERVATION]["a"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["y"] == features[PipelineFeatureType.OBSERVATION]["b"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["c"]
|
||||
|
||||
assert_contract_is_typed(out)
|
||||
# Input not mutated
|
||||
assert set(features) == {"a", "b", "c"}
|
||||
assert set(features[PipelineFeatureType.OBSERVATION]) == {"a", "b", "c"}
|
||||
|
||||
|
||||
def test_features_overlapping_keys(policy_feature_factory):
|
||||
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
|
||||
processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"a": "b", "b": "c"})
|
||||
features = {
|
||||
"a": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
"b": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"a": policy_feature_factory(FeatureType.VISUAL, (1,)),
|
||||
"b": policy_feature_factory(FeatureType.VISUAL, (2,)),
|
||||
},
|
||||
}
|
||||
out = processor.transform_features(features)
|
||||
|
||||
assert set(out) == {"b", "c"}
|
||||
assert out["b"] == features["a"] # 'a' renamed to'b'
|
||||
assert out["c"] == features["b"] # 'b' renamed to 'c'
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {"b", "c"}
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["b"] == features[PipelineFeatureType.OBSERVATION]["a"]
|
||||
) # 'a' renamed to'b'
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["b"]
|
||||
) # 'b' renamed to 'c'
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_features_chained_processors(policy_feature_factory):
|
||||
# Chain two rename processors at the contract level
|
||||
processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||
processor2 = RenameProcessor(
|
||||
processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
|
||||
)
|
||||
pipeline = RobotProcessor([processor1, processor2])
|
||||
pipeline = DataProcessorPipeline([processor1, processor2])
|
||||
|
||||
spec = {
|
||||
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"pos": policy_feature_factory(FeatureType.VISUAL, (7,)),
|
||||
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"extra": policy_feature_factory(FeatureType.VISUAL, (1,)),
|
||||
},
|
||||
}
|
||||
out = pipeline.transform_features(initial_features=spec)
|
||||
|
||||
assert set(out) == {"observation.state", "observation.image", "extra"}
|
||||
assert out["observation.state"] == spec["pos"]
|
||||
assert out["observation.image"] == spec["img"]
|
||||
assert out["extra"] == spec["extra"]
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"}
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.state"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["pos"]
|
||||
)
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.image"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["img"]
|
||||
)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
|
||||
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
@@ -86,20 +86,20 @@ def test_make_sac_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
|
||||
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
|
||||
|
||||
|
||||
def test_sac_processor_normalization_modes():
|
||||
@@ -234,13 +234,13 @@ def test_sac_processor_without_stats():
|
||||
factory_preprocessor, factory_postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
preprocessor = DataProcessorPipeline(
|
||||
factory_preprocessor.steps,
|
||||
name=factory_preprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
postprocessor = RobotProcessor(
|
||||
postprocessor = DataProcessorPipeline(
|
||||
factory_postprocessor.steps,
|
||||
name=factory_postprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
@@ -277,7 +277,7 @@ def test_sac_processor_save_and_load():
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -306,10 +306,25 @@ def test_sac_processor_mixed_precision():
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
# Replace DeviceProcessorStep with one that uses float16
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Update normalizer to use the same device as the device processor
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float16, # Match the float16 dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
|
||||
@@ -374,3 +389,60 @@ def test_sac_processor_edge_cases():
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||
# When action is None, it may still be present with None value
|
||||
assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, _ = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
# Device processor converts to bfloat16
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float32, # Deliberately configured as float32
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Verify initial normalizer configuration
|
||||
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
|
||||
assert normalizer_step.dtype == torch.float32
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32
|
||||
action = torch.randn(5, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
|
||||
|
||||
# Verify normalizer automatically adapted its internal state
|
||||
assert normalizer_step.dtype == torch.bfloat16
|
||||
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
|
||||
@@ -20,7 +20,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.processor_smolvla import (
|
||||
@@ -28,13 +28,31 @@ from lerobot.policies.smolvla.processor_smolvla import (
|
||||
make_smolvla_pre_post_processors,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
ProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
"""Mock tokenizer processor step for testing."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Accept any arguments to mimic the real TokenizerProcessorStep interface
|
||||
pass
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Pass through transition unchanged
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
# Pass through features unchanged
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
@@ -88,7 +106,9 @@ def test_make_smolvla_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
|
||||
with patch(
|
||||
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||
):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -97,22 +117,22 @@ def test_make_smolvla_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 6
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], SmolVLANewLineProcessor)
|
||||
# Step 4 would be TokenizerProcessor but it's mocked
|
||||
assert isinstance(preprocessor.steps[5], DeviceProcessor)
|
||||
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], SmolVLANewLineProcessor)
|
||||
# Step 3 would be TokenizerProcessorStep but it's mocked
|
||||
assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
|
||||
assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
|
||||
|
||||
|
||||
def test_smolvla_newline_processor_single_task():
|
||||
@@ -170,7 +190,7 @@ def test_smolvla_processor_cuda():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -192,7 +212,9 @@ def test_smolvla_processor_cuda():
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
|
||||
with patch(
|
||||
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||
):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -225,7 +247,7 @@ def test_smolvla_processor_accelerate_scenario():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -247,7 +269,9 @@ def test_smolvla_processor_accelerate_scenario():
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
|
||||
with patch(
|
||||
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||
):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -281,7 +305,7 @@ def test_smolvla_processor_multi_gpu():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessor:
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -303,7 +327,9 @@ def test_smolvla_processor_multi_gpu():
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
|
||||
with patch(
|
||||
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||
):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
@@ -334,7 +360,9 @@ def test_smolvla_processor_without_stats():
|
||||
config = create_default_config()
|
||||
|
||||
# Mock the tokenizer processor
|
||||
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
|
||||
with patch(
|
||||
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||
):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
dataset_stats=None,
|
||||
@@ -372,7 +400,77 @@ def test_smolvla_newline_processor_transform_features():
|
||||
|
||||
# Test transform_features
|
||||
features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
}
|
||||
result = processor.transform_features(features)
|
||||
assert result == features # Should return unchanged
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_smolvla_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch(
|
||||
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||
):
|
||||
preprocessor, _ = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
# Device processor converts to bfloat16
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float32, # Deliberately configured as float32
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Verify initial normalizer configuration (SmolVLA has NormalizerProcessorStep at index 5)
|
||||
normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep
|
||||
assert normalizer_step.dtype == torch.float32
|
||||
|
||||
# Create test data with both state and visual observations
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(8, dtype=torch.float32),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(
|
||||
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
|
||||
assert (
|
||||
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
|
||||
) # IDENTITY normalization still gets dtype conversion
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
|
||||
|
||||
# Verify normalizer automatically adapted its internal state
|
||||
assert normalizer_step.dtype == torch.bfloat16
|
||||
# Check state stats (has normalization)
|
||||
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
|
||||
|
||||
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
@@ -89,20 +89,20 @@ def test_make_tdmpc_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
|
||||
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
|
||||
|
||||
|
||||
def test_tdmpc_processor_normalization():
|
||||
@@ -251,13 +251,13 @@ def test_tdmpc_processor_without_stats():
|
||||
factory_preprocessor, factory_postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
preprocessor = DataProcessorPipeline(
|
||||
factory_preprocessor.steps,
|
||||
name=factory_preprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
postprocessor = RobotProcessor(
|
||||
postprocessor = DataProcessorPipeline(
|
||||
factory_postprocessor.steps,
|
||||
name=factory_postprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
@@ -297,7 +297,7 @@ def test_tdmpc_processor_save_and_load():
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -330,10 +330,25 @@ def test_tdmpc_processor_mixed_precision():
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
# Replace DeviceProcessorStep with one that uses float16
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Update normalizer to use the same device as the device processor
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float16, # Match the float16 dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
@@ -410,3 +425,67 @@ def test_tdmpc_processor_edge_cases():
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert OBS_STATE not in processed[TransitionKey.OBSERVATION]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_tdmpc_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, _ = make_tdmpc_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
# Device processor converts to bfloat16
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float32, # Deliberately configured as float32
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Verify initial normalizer configuration
|
||||
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
|
||||
assert normalizer_step.dtype == torch.float32
|
||||
|
||||
# Create test data with both state and visual observations
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(12, dtype=torch.float32),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
|
||||
assert (
|
||||
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
|
||||
) # IDENTITY normalization still gets dtype conversion
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
|
||||
|
||||
# Verify normalizer automatically adapted its internal state
|
||||
assert normalizer_step.dtype == torch.bfloat16
|
||||
# Check state stats (has normalization)
|
||||
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Tests for the TokenizerProcessor class.
|
||||
Tests for the TokenizerProcessorStep class.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
@@ -8,10 +8,9 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
|
||||
from lerobot.processor.tokenizer_processor import TokenizerProcessor
|
||||
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -96,7 +95,7 @@ def test_basic_tokenization(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -128,7 +127,7 @@ def test_basic_tokenization_with_tokenizer_object():
|
||||
"""Test basic string tokenization functionality using tokenizer object directly."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -162,7 +161,7 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -190,7 +189,7 @@ def test_custom_keys(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -216,7 +215,7 @@ def test_none_complementary_data(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
transition = create_transition(complementary_data=None)
|
||||
|
||||
@@ -231,7 +230,7 @@ def test_missing_task_key(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
transition = create_transition(complementary_data={"other_field": "some value"})
|
||||
|
||||
@@ -246,7 +245,7 @@ def test_none_task_value(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
transition = create_transition(complementary_data={"task": None})
|
||||
|
||||
@@ -261,7 +260,7 @@ def test_unsupported_task_type(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
# Test with integer task
|
||||
transition = create_transition(complementary_data={"task": 123})
|
||||
@@ -280,7 +279,7 @@ def test_unsupported_task_type(mock_auto_tokenizer):
|
||||
def test_no_tokenizer_error():
|
||||
"""Test that ValueError is raised when neither tokenizer nor tokenizer_name is provided."""
|
||||
with pytest.raises(ValueError, match="Either 'tokenizer' or 'tokenizer_name' must be provided"):
|
||||
TokenizerProcessor()
|
||||
TokenizerProcessorStep()
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@@ -291,7 +290,7 @@ def test_invalid_tokenizer_name_error():
|
||||
mock_auto_tokenizer.from_pretrained.side_effect = Exception("Model not found")
|
||||
|
||||
with pytest.raises(Exception, match="Model not found"):
|
||||
TokenizerProcessor(tokenizer_name="invalid-tokenizer")
|
||||
TokenizerProcessorStep(tokenizer_name="invalid-tokenizer")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@@ -301,7 +300,7 @@ def test_get_config_with_tokenizer_name(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(
|
||||
processor = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer",
|
||||
max_length=256,
|
||||
task_key="instruction",
|
||||
@@ -328,7 +327,7 @@ def test_get_config_with_tokenizer_object():
|
||||
"""Test configuration serialization when using tokenizer object."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
|
||||
processor = TokenizerProcessor(
|
||||
processor = TokenizerProcessorStep(
|
||||
tokenizer=mock_tokenizer,
|
||||
max_length=256,
|
||||
task_key="instruction",
|
||||
@@ -358,7 +357,7 @@ def test_state_dict_methods(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
# Should return empty dict
|
||||
state = processor.state_dict()
|
||||
@@ -375,7 +374,7 @@ def test_reset_method(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
# Should not raise error
|
||||
processor.reset()
|
||||
@@ -388,8 +387,10 @@ def test_integration_with_robot_processor(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
|
||||
robot_processor = RobotProcessor([tokenizer_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[tokenizer_processor], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -423,18 +424,20 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
original_processor = TokenizerProcessor(
|
||||
original_processor = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer", max_length=32, task_key="instruction"
|
||||
)
|
||||
|
||||
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[original_processor], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Save processor
|
||||
robot_processor.save_pretrained(temp_dir)
|
||||
|
||||
# Load processor - tokenizer will be recreated from saved config
|
||||
loaded_processor = RobotProcessor.from_pretrained(
|
||||
loaded_processor = DataProcessorPipeline.from_pretrained(
|
||||
temp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -456,16 +459,20 @@ def test_save_and_load_pretrained_with_tokenizer_object():
|
||||
"""Test saving and loading processor with tokenizer object using overrides."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
|
||||
original_processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=32, task_key="instruction")
|
||||
original_processor = TokenizerProcessorStep(
|
||||
tokenizer=mock_tokenizer, max_length=32, task_key="instruction"
|
||||
)
|
||||
|
||||
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[original_processor], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Save processor
|
||||
robot_processor.save_pretrained(temp_dir)
|
||||
|
||||
# Load processor with tokenizer override (since tokenizer object wasn't saved)
|
||||
loaded_processor = RobotProcessor.from_pretrained(
|
||||
loaded_processor = DataProcessorPipeline.from_pretrained(
|
||||
temp_dir,
|
||||
overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}},
|
||||
to_transition=lambda x: x,
|
||||
@@ -488,40 +495,44 @@ def test_save_and_load_pretrained_with_tokenizer_object():
|
||||
@require_package("transformers")
|
||||
def test_registry_functionality():
|
||||
"""Test that the processor is properly registered."""
|
||||
from lerobot.processor.pipeline import ProcessorStepRegistry
|
||||
from lerobot.processor import ProcessorStepRegistry
|
||||
|
||||
# Check that the processor is registered
|
||||
assert "tokenizer_processor" in ProcessorStepRegistry.list()
|
||||
|
||||
# Check that we can retrieve it
|
||||
retrieved_class = ProcessorStepRegistry.get("tokenizer_processor")
|
||||
assert retrieved_class is TokenizerProcessor
|
||||
assert retrieved_class is TokenizerProcessorStep
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_features_basic():
|
||||
"""Test basic feature contract functionality."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=128)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128)
|
||||
|
||||
input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
||||
},
|
||||
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
}
|
||||
|
||||
output_features = processor.transform_features(input_features)
|
||||
|
||||
# Check that original features are preserved
|
||||
assert "observation.state" in output_features
|
||||
assert "action" in output_features
|
||||
assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert "action" in output_features[PipelineFeatureType.ACTION]
|
||||
|
||||
# Check that tokenized features are added
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in output_features
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
|
||||
# Check feature properties
|
||||
tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][
|
||||
f"{OBS_LANGUAGE}.attention_mask"
|
||||
]
|
||||
|
||||
assert tokens_feature.type == FeatureType.LANGUAGE
|
||||
assert tokens_feature.shape == (128,)
|
||||
@@ -533,17 +544,19 @@ def test_features_basic():
|
||||
def test_features_with_custom_max_length():
|
||||
"""Test feature contract with custom max_length."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=64)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=64)
|
||||
|
||||
input_features = {}
|
||||
input_features = {PipelineFeatureType.OBSERVATION: {}}
|
||||
output_features = processor.transform_features(input_features)
|
||||
|
||||
# Check that features use correct max_length
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in output_features
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
|
||||
tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][
|
||||
f"{OBS_LANGUAGE}.attention_mask"
|
||||
]
|
||||
|
||||
assert tokens_feature.shape == (64,)
|
||||
assert attention_mask_feature.shape == (64,)
|
||||
@@ -553,18 +566,22 @@ def test_features_with_custom_max_length():
|
||||
def test_features_existing_features():
|
||||
"""Test feature contract when tokenized features already exist."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=256)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=256)
|
||||
|
||||
input_features = {
|
||||
f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
|
||||
f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
|
||||
f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
|
||||
}
|
||||
}
|
||||
|
||||
output_features = processor.transform_features(input_features)
|
||||
|
||||
# Should not overwrite existing features
|
||||
assert output_features[f"{OBS_LANGUAGE}.tokens"].shape == (100,) # Original shape preserved
|
||||
assert output_features[f"{OBS_LANGUAGE}.attention_mask"].shape == (100,)
|
||||
assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"].shape == (
|
||||
100,
|
||||
) # Original shape preserved
|
||||
assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"].shape == (100,)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@@ -590,7 +607,7 @@ def test_tokenization_parameters(mock_auto_tokenizer):
|
||||
tracking_tokenizer = TrackingMockTokenizer()
|
||||
mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(
|
||||
processor = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer",
|
||||
max_length=16,
|
||||
padding="longest",
|
||||
@@ -622,7 +639,7 @@ def test_preserves_other_complementary_data(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -657,7 +674,7 @@ def test_deterministic_tokenization(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -685,7 +702,7 @@ def test_empty_string_task(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -709,7 +726,7 @@ def test_very_long_task(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=5, truncation=True)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=5, truncation=True)
|
||||
|
||||
long_task = " ".join(["word"] * 100) # Very long task
|
||||
transition = create_transition(
|
||||
@@ -759,7 +776,9 @@ def test_custom_padding_side(mock_auto_tokenizer):
|
||||
mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer
|
||||
|
||||
# Test left padding
|
||||
processor_left = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="left")
|
||||
processor_left = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer", max_length=10, padding_side="left"
|
||||
)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -771,7 +790,9 @@ def test_custom_padding_side(mock_auto_tokenizer):
|
||||
assert tracking_tokenizer.padding_side_calls[-1] == "left"
|
||||
|
||||
# Test right padding (default)
|
||||
processor_right = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="right")
|
||||
processor_right = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer", max_length=10, padding_side="right"
|
||||
)
|
||||
|
||||
processor_right(transition)
|
||||
|
||||
@@ -782,7 +803,7 @@ def test_custom_padding_side(mock_auto_tokenizer):
|
||||
def test_device_detection_cpu():
|
||||
"""Test that tokenized tensors stay on CPU when other tensors are on CPU."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CPU tensors
|
||||
observation = {"observation.state": torch.randn(10)} # CPU tensor
|
||||
@@ -806,7 +827,7 @@ def test_device_detection_cpu():
|
||||
def test_device_detection_cuda():
|
||||
"""Test that tokenized tensors are moved to CUDA when other tensors are on CUDA."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CUDA tensors
|
||||
observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor
|
||||
@@ -831,7 +852,7 @@ def test_device_detection_cuda():
|
||||
def test_device_detection_multi_gpu():
|
||||
"""Test that tokenized tensors match device in multi-GPU setup."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Test with tensors on cuda:1
|
||||
device = torch.device("cuda:1")
|
||||
@@ -855,7 +876,7 @@ def test_device_detection_multi_gpu():
|
||||
def test_device_detection_no_tensors():
|
||||
"""Test that tokenized tensors stay on CPU when no other tensors exist."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with no tensors
|
||||
transition = create_transition(
|
||||
@@ -877,7 +898,7 @@ def test_device_detection_no_tensors():
|
||||
def test_device_detection_mixed_devices():
|
||||
"""Test device detection when tensors are on different devices (uses first found)."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Create transition with mixed devices
|
||||
@@ -905,7 +926,7 @@ def test_device_detection_mixed_devices():
|
||||
def test_device_detection_from_action():
|
||||
"""Test that device is detected from action tensor when no observation tensors exist."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with action on CUDA but no observation tensors
|
||||
observation = {"metadata": {"key": "value"}} # No tensors in observation
|
||||
@@ -928,7 +949,7 @@ def test_device_detection_from_action():
|
||||
def test_device_detection_preserves_dtype():
|
||||
"""Test that device detection doesn't affect dtype of tokenized tensors."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with float tensor (to test dtype isn't affected)
|
||||
observation = {"observation.state": torch.randn(10, dtype=torch.float16)}
|
||||
@@ -948,16 +969,16 @@ def test_device_detection_preserves_dtype():
|
||||
@require_package("transformers")
|
||||
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
|
||||
def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
"""Test that TokenizerProcessor works correctly with DeviceProcessor in pipeline."""
|
||||
from lerobot.processor import DeviceProcessor
|
||||
"""Test that TokenizerProcessorStep works correctly with DeviceProcessorStep in pipeline."""
|
||||
from lerobot.processor import DeviceProcessorStep
|
||||
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
# Create pipeline with TokenizerProcessor then DeviceProcessor
|
||||
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
|
||||
device_processor = DeviceProcessor(device="cuda:0")
|
||||
robot_processor = RobotProcessor(
|
||||
# Create pipeline with TokenizerProcessorStep then DeviceProcessorStep
|
||||
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
|
||||
device_processor = DeviceProcessorStep(device="cuda:0")
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[tokenizer_processor, device_processor], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -970,7 +991,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
|
||||
result = robot_processor(transition)
|
||||
|
||||
# All tensors should end up on CUDA (moved by DeviceProcessor)
|
||||
# All tensors should end up on CUDA (moved by DeviceProcessorStep)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
@@ -986,7 +1007,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
def test_simulated_accelerate_scenario():
|
||||
"""Test scenario simulating Accelerate with data already on GPU."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Simulate Accelerate scenario: batch already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
@@ -25,14 +25,14 @@ from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
DeviceProcessor,
|
||||
NormalizerProcessor,
|
||||
RenameProcessor,
|
||||
RobotProcessor,
|
||||
ToBatchProcessor,
|
||||
UnnormalizerProcessor,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
@@ -89,20 +89,20 @@ def test_make_vqbet_processor_basic():
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessor)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessor)
|
||||
assert isinstance(preprocessor.steps[2], ToBatchProcessor)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessor)
|
||||
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
|
||||
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessor)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessor)
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
|
||||
|
||||
|
||||
def test_vqbet_processor_with_images():
|
||||
@@ -244,13 +244,13 @@ def test_vqbet_processor_without_stats():
|
||||
factory_preprocessor, factory_postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
preprocessor = DataProcessorPipeline(
|
||||
factory_preprocessor.steps,
|
||||
name=factory_preprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
postprocessor = RobotProcessor(
|
||||
postprocessor = DataProcessorPipeline(
|
||||
factory_postprocessor.steps,
|
||||
name=factory_postprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
@@ -290,7 +290,7 @@ def test_vqbet_processor_save_and_load():
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
@@ -323,10 +323,25 @@ def test_vqbet_processor_mixed_precision():
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
if isinstance(step, DeviceProcessor):
|
||||
preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
|
||||
# Replace DeviceProcessorStep with one that uses float16
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Update normalizer to use the same device as the device processor
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float16, # Match the float16 dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Create test data
|
||||
observation = {
|
||||
@@ -405,3 +420,68 @@ def test_vqbet_processor_sequential_processing():
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
|
||||
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
|
||||
assert result[TransitionKey.ACTION].shape == (1, 7)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_vqbet_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, _ = make_vqbet_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
# Device processor converts to bfloat16
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=step.features,
|
||||
norm_map=step.norm_map,
|
||||
stats=step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float32, # Deliberately configured as float32
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Verify initial normalizer configuration
|
||||
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
|
||||
assert normalizer_step.dtype == torch.float32
|
||||
|
||||
# Create test data with both state and visual observations
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(8, dtype=torch.float32),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
|
||||
assert (
|
||||
processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
|
||||
) # IDENTITY normalization still gets dtype conversion
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
|
||||
|
||||
# Verify normalizer automatically adapted its internal state
|
||||
assert normalizer_step.dtype == torch.bfloat16
|
||||
# Check state stats (has normalization)
|
||||
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.robots.reachy2 import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Robot,
|
||||
Reachy2RobotConfig,
|
||||
)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_JOINTS = {
|
||||
**REACHY2_NECK_JOINTS,
|
||||
**REACHY2_ANTENNAS_JOINTS,
|
||||
**REACHY2_R_ARM_JOINTS,
|
||||
**REACHY2_L_ARM_JOINTS,
|
||||
}
|
||||
|
||||
PARAMS = [
|
||||
{}, # default config
|
||||
{"with_mobile_base": False},
|
||||
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
|
||||
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
|
||||
{"use_external_commands": True, "disable_torque_on_disconnect": True},
|
||||
{"use_external_commands": True, "with_mobile_base": False, "with_neck": False},
|
||||
{"disable_torque_on_disconnect": False},
|
||||
{"max_relative_target": 5},
|
||||
{"with_right_teleop_camera": False},
|
||||
{"with_left_teleop_camera": False, "with_right_teleop_camera": False},
|
||||
{"with_left_teleop_camera": False, "with_torso_camera": True},
|
||||
]
|
||||
|
||||
|
||||
def _make_reachy2_sdk_mock():
|
||||
class JointSpy:
|
||||
__slots__ = (
|
||||
"present_position",
|
||||
"_goal_position",
|
||||
"_on_set",
|
||||
)
|
||||
|
||||
def __init__(self, present_position=0.0, on_set=None):
|
||||
self.present_position = present_position
|
||||
self._goal_position = present_position
|
||||
self._on_set = on_set
|
||||
|
||||
@property
|
||||
def goal_position(self):
|
||||
return self._goal_position
|
||||
|
||||
@goal_position.setter
|
||||
def goal_position(self, v):
|
||||
self._goal_position = v
|
||||
if self._on_set:
|
||||
self._on_set()
|
||||
|
||||
r = MagicMock(name="ReachySDKMock")
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _connect():
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _disconnect():
|
||||
r.is_connected.return_value = False
|
||||
|
||||
# Global counter of goal_position sets
|
||||
r._goal_position_set_total = 0
|
||||
|
||||
def _on_any_goal_set():
|
||||
r._goal_position_set_total += 1
|
||||
|
||||
# Mock joints with some dummy positions
|
||||
joints = {
|
||||
k: JointSpy(
|
||||
present_position=float(i),
|
||||
on_set=_on_any_goal_set,
|
||||
)
|
||||
for i, k in enumerate(REACHY2_JOINTS.values())
|
||||
}
|
||||
r.joints = joints
|
||||
|
||||
# Mock mobile base with some dummy odometry
|
||||
r.mobile_base = MagicMock()
|
||||
r.mobile_base.odometry = {
|
||||
"x": 0.1,
|
||||
"y": -0.2,
|
||||
"theta": 21.3,
|
||||
"vx": 0.001,
|
||||
"vy": 0.002,
|
||||
"vtheta": 0.0,
|
||||
}
|
||||
|
||||
r.connect = MagicMock(side_effect=_connect)
|
||||
r.disconnect = MagicMock(side_effect=_disconnect)
|
||||
|
||||
# Mock methods
|
||||
r.turn_on = MagicMock()
|
||||
r.reset_default_limits = MagicMock()
|
||||
r.send_goal_positions = MagicMock()
|
||||
r.turn_off_smoothly = MagicMock()
|
||||
r.mobile_base.set_goal_speed = MagicMock()
|
||||
r.mobile_base.send_speed_command = MagicMock()
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def _make_reachy2_camera_mock(*args, **kwargs):
|
||||
cfg = args[0] if args else kwargs.get("config")
|
||||
name = getattr(cfg, "name", kwargs.get("name", "cam"))
|
||||
image_type = getattr(cfg, "image_type", kwargs.get("image_type", "cam"))
|
||||
width = getattr(cfg, "width", kwargs.get("width", 640))
|
||||
height = getattr(cfg, "height", kwargs.get("height", 480))
|
||||
|
||||
cam = MagicMock(name=f"Reachy2CameraMock:{name}")
|
||||
cam.name = name
|
||||
cam.image_type = image_type
|
||||
cam.width = width
|
||||
cam.height = height
|
||||
cam.connect = MagicMock()
|
||||
cam.disconnect = MagicMock()
|
||||
cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8))
|
||||
return cam
|
||||
|
||||
|
||||
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
|
||||
def reachy2(request):
|
||||
with (
|
||||
patch(
|
||||
"lerobot.robots.reachy2.robot_reachy2.ReachySDK",
|
||||
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
|
||||
),
|
||||
patch(
|
||||
"lerobot.cameras.reachy2_camera.reachy2_camera.Reachy2Camera",
|
||||
side_effect=_make_reachy2_camera_mock,
|
||||
),
|
||||
):
|
||||
overrides = request.param
|
||||
cfg = Reachy2RobotConfig(ip_address="192.168.0.200", **overrides)
|
||||
robot = Reachy2Robot(cfg)
|
||||
yield robot
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_connect_disconnect(reachy2):
|
||||
assert not reachy2.is_connected
|
||||
|
||||
reachy2.connect()
|
||||
assert reachy2.is_connected
|
||||
|
||||
reachy2.reachy.turn_on.assert_called_once()
|
||||
reachy2.reachy.reset_default_limits.assert_called_once()
|
||||
|
||||
reachy2.disconnect()
|
||||
assert not reachy2.is_connected
|
||||
|
||||
if reachy2.config.disable_torque_on_disconnect:
|
||||
reachy2.reachy.turn_off_smoothly.assert_called_once()
|
||||
else:
|
||||
reachy2.reachy.turn_off_smoothly.assert_not_called()
|
||||
reachy2.reachy.disconnect.assert_called_once()
|
||||
|
||||
|
||||
def test_get_joints_dict(reachy2):
|
||||
reachy2.connect()
|
||||
|
||||
if reachy2.config.with_neck:
|
||||
assert "neck_yaw.pos" in reachy2.joints_dict
|
||||
assert "neck_pitch.pos" in reachy2.joints_dict
|
||||
assert "neck_roll.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "neck_yaw.pos" not in reachy2.joints_dict
|
||||
assert "neck_pitch.pos" not in reachy2.joints_dict
|
||||
assert "neck_roll.pos" not in reachy2.joints_dict
|
||||
|
||||
if reachy2.config.with_antennas:
|
||||
assert "l_antenna.pos" in reachy2.joints_dict
|
||||
assert "r_antenna.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "l_antenna.pos" not in reachy2.joints_dict
|
||||
assert "r_antenna.pos" not in reachy2.joints_dict
|
||||
|
||||
if reachy2.config.with_r_arm:
|
||||
assert "r_shoulder_pitch.pos" in reachy2.joints_dict
|
||||
assert "r_shoulder_roll.pos" in reachy2.joints_dict
|
||||
assert "r_elbow_yaw.pos" in reachy2.joints_dict
|
||||
assert "r_elbow_pitch.pos" in reachy2.joints_dict
|
||||
assert "r_wrist_roll.pos" in reachy2.joints_dict
|
||||
assert "r_wrist_pitch.pos" in reachy2.joints_dict
|
||||
assert "r_wrist_yaw.pos" in reachy2.joints_dict
|
||||
assert "r_gripper.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "r_shoulder_pitch.pos" not in reachy2.joints_dict
|
||||
assert "r_shoulder_roll.pos" not in reachy2.joints_dict
|
||||
assert "r_elbow_yaw.pos" not in reachy2.joints_dict
|
||||
assert "r_elbow_pitch.pos" not in reachy2.joints_dict
|
||||
assert "r_wrist_roll.pos" not in reachy2.joints_dict
|
||||
assert "r_wrist_pitch.pos" not in reachy2.joints_dict
|
||||
assert "r_wrist_yaw.pos" not in reachy2.joints_dict
|
||||
assert "r_gripper.pos" not in reachy2.joints_dict
|
||||
|
||||
if reachy2.config.with_l_arm:
|
||||
assert "l_shoulder_pitch.pos" in reachy2.joints_dict
|
||||
assert "l_shoulder_roll.pos" in reachy2.joints_dict
|
||||
assert "l_elbow_yaw.pos" in reachy2.joints_dict
|
||||
assert "l_elbow_pitch.pos" in reachy2.joints_dict
|
||||
assert "l_wrist_roll.pos" in reachy2.joints_dict
|
||||
assert "l_wrist_pitch.pos" in reachy2.joints_dict
|
||||
assert "l_wrist_yaw.pos" in reachy2.joints_dict
|
||||
assert "l_gripper.pos" in reachy2.joints_dict
|
||||
else:
|
||||
assert "l_shoulder_pitch.pos" not in reachy2.joints_dict
|
||||
assert "l_shoulder_roll.pos" not in reachy2.joints_dict
|
||||
assert "l_elbow_yaw.pos" not in reachy2.joints_dict
|
||||
assert "l_elbow_pitch.pos" not in reachy2.joints_dict
|
||||
assert "l_wrist_roll.pos" not in reachy2.joints_dict
|
||||
assert "l_wrist_pitch.pos" not in reachy2.joints_dict
|
||||
assert "l_wrist_yaw.pos" not in reachy2.joints_dict
|
||||
assert "l_gripper.pos" not in reachy2.joints_dict
|
||||
|
||||
|
||||
def test_get_observation(reachy2):
|
||||
reachy2.connect()
|
||||
obs = reachy2.get_observation()
|
||||
|
||||
expected_keys = set(reachy2.joints_dict)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
|
||||
expected_keys.update(reachy2.cameras.keys())
|
||||
assert set(obs.keys()) == expected_keys
|
||||
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
||||
if reachy2.config.with_mobile_base:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
||||
if reachy2.config.with_left_teleop_camera:
|
||||
assert obs["teleop_left"].shape == (
|
||||
reachy2.config.cameras["teleop_left"].height,
|
||||
reachy2.config.cameras["teleop_left"].width,
|
||||
3,
|
||||
)
|
||||
if reachy2.config.with_right_teleop_camera:
|
||||
assert obs["teleop_right"].shape == (
|
||||
reachy2.config.cameras["teleop_right"].height,
|
||||
reachy2.config.cameras["teleop_right"].width,
|
||||
3,
|
||||
)
|
||||
if reachy2.config.with_torso_camera:
|
||||
assert obs["torso_rgb"].shape == (
|
||||
reachy2.config.cameras["torso_rgb"].height,
|
||||
reachy2.config.cameras["torso_rgb"].width,
|
||||
3,
|
||||
)
|
||||
|
||||
|
||||
def test_send_action(reachy2):
|
||||
reachy2.connect()
|
||||
|
||||
action = {k: i * 10.0 for i, k in enumerate(reachy2.joints_dict.keys(), start=1)}
|
||||
if reachy2.config.with_mobile_base:
|
||||
action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)})
|
||||
|
||||
previous_present_position = {
|
||||
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys()
|
||||
}
|
||||
returned = reachy2.send_action(action)
|
||||
|
||||
if reachy2.config.max_relative_target is None:
|
||||
assert returned == action
|
||||
|
||||
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
expected_pos = action[motor]
|
||||
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
|
||||
if reachy2.config.max_relative_target is None:
|
||||
assert real_pos == expected_pos
|
||||
else:
|
||||
assert real_pos == previous_present_position[motor] + np.sign(expected_pos) * min(
|
||||
abs(expected_pos - real_pos), reachy2.config.max_relative_target
|
||||
)
|
||||
|
||||
if reachy2.config.with_mobile_base:
|
||||
goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)]
|
||||
reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)
|
||||
|
||||
if reachy2.config.use_external_commands:
|
||||
reachy2.reachy.send_goal_positions.assert_not_called()
|
||||
if reachy2.config.with_mobile_base:
|
||||
reachy2.reachy.mobile_base.send_speed_command.assert_not_called()
|
||||
else:
|
||||
reachy2.reachy.send_goal_positions.assert_called_once()
|
||||
if reachy2.config.with_mobile_base:
|
||||
reachy2.reachy.mobile_base.send_speed_command.assert_called_once()
|
||||
|
||||
|
||||
def test_no_part_declared():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2RobotConfig(
|
||||
ip_address="192.168.0.200",
|
||||
with_mobile_base=False,
|
||||
with_l_arm=False,
|
||||
with_r_arm=False,
|
||||
with_neck=False,
|
||||
with_antennas=False,
|
||||
)
|
||||
@@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.teleoperators.reachy2_teleoperator import (
|
||||
REACHY2_ANTENNAS_JOINTS,
|
||||
REACHY2_L_ARM_JOINTS,
|
||||
REACHY2_NECK_JOINTS,
|
||||
REACHY2_R_ARM_JOINTS,
|
||||
REACHY2_VEL,
|
||||
Reachy2Teleoperator,
|
||||
Reachy2TeleoperatorConfig,
|
||||
)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_JOINTS = {
|
||||
**REACHY2_NECK_JOINTS,
|
||||
**REACHY2_ANTENNAS_JOINTS,
|
||||
**REACHY2_R_ARM_JOINTS,
|
||||
**REACHY2_L_ARM_JOINTS,
|
||||
}
|
||||
|
||||
PARAMS = [
|
||||
{}, # default config
|
||||
{"with_mobile_base": False},
|
||||
{"with_mobile_base": False, "with_l_arm": False, "with_antennas": False},
|
||||
{"with_r_arm": False, "with_neck": False, "with_antennas": False},
|
||||
{"with_mobile_base": False, "with_neck": False},
|
||||
{"use_present_position": True},
|
||||
]
|
||||
|
||||
|
||||
def _make_reachy2_sdk_mock():
|
||||
r = MagicMock(name="ReachySDKMock")
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _connect():
|
||||
r.is_connected.return_value = True
|
||||
|
||||
def _disconnect():
|
||||
r.is_connected.return_value = False
|
||||
|
||||
# Mock joints with some dummy positions
|
||||
joints = {
|
||||
k: MagicMock(
|
||||
present_position=float(i),
|
||||
goal_position=float(i) + 0.5,
|
||||
)
|
||||
for i, k in enumerate(REACHY2_JOINTS.values())
|
||||
}
|
||||
r.joints = joints
|
||||
|
||||
# Mock mobile base with some dummy odometry
|
||||
r.mobile_base = MagicMock()
|
||||
r.mobile_base.last_cmd_vel = {
|
||||
"vx": -0.2,
|
||||
"vy": 0.2,
|
||||
"vtheta": 11.0,
|
||||
}
|
||||
r.mobile_base.odometry = {
|
||||
"x": 1.0,
|
||||
"y": 2.0,
|
||||
"theta": 20.0,
|
||||
"vx": 0.1,
|
||||
"vy": -0.1,
|
||||
"vtheta": 8.0,
|
||||
}
|
||||
|
||||
r.connect = MagicMock(side_effect=_connect)
|
||||
r.disconnect = MagicMock(side_effect=_disconnect)
|
||||
|
||||
return r
|
||||
|
||||
|
||||
@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys()))
|
||||
def reachy2(request):
|
||||
with (
|
||||
patch(
|
||||
"lerobot.teleoperators.reachy2_teleoperator.reachy2_teleoperator.ReachySDK",
|
||||
side_effect=lambda *a, **k: _make_reachy2_sdk_mock(),
|
||||
),
|
||||
):
|
||||
overrides = request.param
|
||||
cfg = Reachy2TeleoperatorConfig(ip_address="192.168.0.200", **overrides)
|
||||
robot = Reachy2Teleoperator(cfg)
|
||||
yield robot
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_connect_disconnect(reachy2):
|
||||
assert not reachy2.is_connected
|
||||
|
||||
reachy2.connect()
|
||||
assert reachy2.is_connected
|
||||
|
||||
reachy2.disconnect()
|
||||
assert not reachy2.is_connected
|
||||
|
||||
reachy2.reachy.disconnect.assert_called_once()
|
||||
|
||||
|
||||
def test_get_action(reachy2):
|
||||
reachy2.connect()
|
||||
action = reachy2.get_action()
|
||||
|
||||
expected_keys = set(reachy2.joints_dict)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
|
||||
assert set(action.keys()) == expected_keys
|
||||
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
if reachy2.config.use_present_position:
|
||||
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
||||
else:
|
||||
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
|
||||
if reachy2.config.with_mobile_base:
|
||||
if reachy2.config.use_present_position:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
||||
else:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]]
|
||||
|
||||
|
||||
def test_no_part_declared():
|
||||
with pytest.raises(ValueError):
|
||||
_ = Reachy2TeleoperatorConfig(
|
||||
ip_address="192.168.0.200",
|
||||
with_mobile_base=False,
|
||||
with_l_arm=False,
|
||||
with_r_arm=False,
|
||||
with_neck=False,
|
||||
with_antennas=False,
|
||||
)
|
||||
@@ -5,7 +5,7 @@ from types import SimpleNamespace
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from lerobot.processor import TransitionKey
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -86,7 +86,10 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
TransitionKey.ACTION: act,
|
||||
}
|
||||
|
||||
vu.log_rerun_data(transition)
|
||||
# Extract observation and action data from transition like in the real call sites
|
||||
obs_data = transition.get(TransitionKey.OBSERVATION, {})
|
||||
action_data = transition.get(TransitionKey.ACTION, {})
|
||||
vu.log_rerun_data(observation=obs_data, action=action_data)
|
||||
|
||||
# We expect:
|
||||
# - observation.state.temperature -> Scalar
|
||||
@@ -141,7 +144,9 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
"vec": np.array([9, 8, 7], dtype=np.float32),
|
||||
}
|
||||
|
||||
vu.log_rerun_data([obs_plain, act_plain])
|
||||
# Extract observation and action data from list like the old function logic did
|
||||
# First dict was treated as observation, second as action
|
||||
vu.log_rerun_data(observation=obs_plain, action=act_plain)
|
||||
|
||||
# Expected keys with auto-prefixes
|
||||
expected = {
|
||||
@@ -181,7 +186,6 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
|
||||
vu.log_rerun_data(
|
||||
None,
|
||||
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
|
||||
action={"action.a": 1.0},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user