mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 15:19:43 +00:00
Compare commits
136 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f247aa0701 | |||
| 1ac6a6d3fe | |||
| e698c709d8 | |||
| a988da4789 | |||
| 99963b6968 | |||
| 332ca4ccc5 | |||
| fc43246942 | |||
| 793ad86fc9 | |||
| a6dbb65917 | |||
| 6c7169c4af | |||
| f125d5e3bf | |||
| 75dcfd4886 | |||
| ff3cbaa872 | |||
| ce793cde64 | |||
| 029c4a9a76 | |||
| d893bf1e30 | |||
| 8c796b39f5 | |||
| 4ebe482a7e | |||
| 2fcc358e98 | |||
| b052843f08 | |||
| ebb464c255 | |||
| 2914ae2a96 | |||
| 645c87e3a9 | |||
| 2c802ac134 | |||
| 15ffc01fb3 | |||
| a837685bf8 | |||
| d32b76cc66 | |||
| 08fb310eaa | |||
| 574a708950 | |||
| ce665160ae | |||
| 35c5d43255 | |||
| 95c1e32aa5 | |||
| e4db65a127 | |||
| 0053defa2e | |||
| fd5d8b3d5f | |||
| 5bf82f8229 | |||
| 5ca3920611 | |||
| 8bde9d0ab7 | |||
| abcbc16126 | |||
| e4fd30a8d4 | |||
| 5f759b1637 | |||
| 6a75b4761a | |||
| e5ade5565d | |||
| 0524551f52 | |||
| 862bc7ef85 | |||
| d38792d6e5 | |||
| db3cf0158c | |||
| 0535f2a59a | |||
| 2805ae347c | |||
| 28ef6fcd14 | |||
| 7fc7ec75bb | |||
| 87890cbf38 | |||
| 5326ffe77e | |||
| a1734cf575 | |||
| 82f300e880 | |||
| 3e7c9d7afc | |||
| e9cb779eab | |||
| 8ff95be04c | |||
| f02ce69df0 | |||
| 1feb7b5d88 | |||
| fbe9009db2 | |||
| c0013b130b | |||
| c4763f61a1 | |||
| b95c219d96 | |||
| 9b1138171e | |||
| 023b8f3466 | |||
| 1cad87ebd2 | |||
| 99de7567e6 | |||
| 21baa8fa02 | |||
| 8b4a5368b3 | |||
| f5c6b03b61 | |||
| e7be2fd113 | |||
| b632490b4b | |||
| 9a9c7208d2 | |||
| 427b97d198 | |||
| 2c2bb1e8bf | |||
| 4b24f94225 | |||
| 670a278cbc | |||
| fc74001202 | |||
| f14ac5d486 | |||
| 7bd0d62ce5 | |||
| 7eccefe235 | |||
| b72274066e | |||
| 20f2910b63 | |||
| fd4ae3466b | |||
| 7beb040e8e | |||
| 05bd18f453 | |||
| 8077456c00 | |||
| 5595887fd0 | |||
| 41959389b6 | |||
| 2c4e888c7f | |||
| 5ced72e6b8 | |||
| 907023f9f7 | |||
| 4ba23ea029 | |||
| 409ac0baca | |||
| 699363f9fc | |||
| ae7a54de57 | |||
| fb9139b882 | |||
| 9fe3a3fb17 | |||
| 26cb9a24c3 | |||
| 77106697c3 | |||
| 75bc44c166 | |||
| f2b79656eb | |||
| 14c2ece004 | |||
| 35612c61e1 | |||
| f7bb3e2d90 | |||
| 1e0d667a22 | |||
| 33969a0337 | |||
| fa26290e8c | |||
| e9f7f5127b | |||
| 097842c70f | |||
| 3b8a3a32a0 | |||
| 1c56779dd9 | |||
| 83a4338f8b | |||
| 730c7b2f35 | |||
| 116059a43e | |||
| b08149a113 | |||
| c227107f60 | |||
| 01dc289f3d | |||
| 6830ca7645 | |||
| ed42c71fc3 | |||
| e0139065bd | |||
| e509f255af | |||
| e2fcd140b0 | |||
| 2a7a0e6129 | |||
| 9f33791b19 | |||
| 453e0a995f | |||
| 8ebf79c494 | |||
| 8774aec304 | |||
| ac742c9f0d | |||
| cd13f1ecfd | |||
| 9aa632968f | |||
| 62caaf07b0 | |||
| 3355f04ca6 | |||
| 769f531603 | |||
| f6c7287ae7 |
+382
-56
@@ -4,7 +4,13 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient
|
||||
|
||||
HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process.
|
||||
|
||||
It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
It combines three key ingredients:
|
||||
|
||||
1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
|
||||
|
||||
2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
|
||||
|
||||
3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
|
||||
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
|
||||
|
||||
@@ -56,30 +62,243 @@ pip install -e ".[hilserl]"
|
||||
|
||||
### Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as:
|
||||
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
class GymManipulatorConfig:
|
||||
env: HILSerlRobotEnvConfig # Environment configuration (nested)
|
||||
dataset: DatasetConfig # Dataset recording/replay configuration (nested)
|
||||
mode: str | None = None # "record", "replay", or None (for training)
|
||||
device: str = "cpu" # Compute device
|
||||
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`)
|
||||
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`)
|
||||
wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py`
|
||||
fps: int = 10 # Control frequency
|
||||
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm
|
||||
processor: HILSerlProcessorConfig # Processing pipeline configuration (nested)
|
||||
name: str = "real_robot" # Environment name
|
||||
mode: str = None # "record", "replay", or None (for training)
|
||||
repo_id: str | None = None # LeRobot dataset repository ID
|
||||
dataset_root: str | None = None # Local dataset root (optional)
|
||||
task: str = "" # Task identifier
|
||||
num_episodes: int = 10 # Number of episodes for recording
|
||||
episode: int = 0 # episode index for replay
|
||||
device: str = "cuda" # Compute device
|
||||
push_to_hub: bool = True # Whether to push the recorded datasets to Hub
|
||||
pretrained_policy_name_or_path: str | None = None # For policy loading
|
||||
reward_classifier_pretrained_path: str | None = None # For reward model
|
||||
number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier
|
||||
task: str | None = None # Task identifier
|
||||
fps: int = 10 # Control frequency
|
||||
|
||||
# Nested processor configuration
|
||||
class HILSerlProcessorConfig:
|
||||
control_mode: str = "gamepad" # Control mode
|
||||
observation: ObservationConfig | None = None # Observation processing settings
|
||||
image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings
|
||||
gripper: GripperConfig | None = None # Gripper control and penalty settings
|
||||
reset: ResetConfig | None = None # Environment reset and timing settings
|
||||
inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings
|
||||
reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings
|
||||
max_gripper_pos: float | None = 100.0 # Maximum gripper position
|
||||
|
||||
# Sub-configuration classes
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters
|
||||
resize_size: tuple[int, int] | None = None # Target image size
|
||||
|
||||
class GripperConfig:
|
||||
use_gripper: bool = True # Enable gripper control
|
||||
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
|
||||
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
|
||||
|
||||
class ResetConfig:
|
||||
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
|
||||
reset_time_s: float = 5.0 # Time to wait during reset
|
||||
control_time_s: float = 20.0 # Maximum episode duration
|
||||
terminate_on_success: bool = True # Whether to terminate episodes on success detection
|
||||
|
||||
class InverseKinematicsConfig:
|
||||
urdf_path: str | None = None # Path to robot URDF file
|
||||
target_frame_name: str | None = None # End-effector frame name
|
||||
end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds
|
||||
end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis
|
||||
|
||||
class RewardClassifierConfig:
|
||||
pretrained_path: str | None = None # Path to pretrained reward classifier
|
||||
success_threshold: float = 0.5 # Success detection threshold
|
||||
success_reward: float = 1.0 # Reward value for successful episodes
|
||||
|
||||
# Dataset configuration
|
||||
class DatasetConfig:
|
||||
repo_id: str # LeRobot dataset repository ID
|
||||
task: str # Task identifier
|
||||
root: str | None = None # Local dataset root directory
|
||||
num_episodes_to_record: int = 5 # Number of episodes for recording
|
||||
replay_episode: int | None = None # Episode index for replay
|
||||
push_to_hub: bool = False # Whether to push datasets to Hub
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Processor Pipeline Architecture
|
||||
|
||||
HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components:
|
||||
|
||||
#### Environment Processor Pipeline
|
||||
|
||||
The environment processor (`env_processor`) handles incoming observations and environment state:
|
||||
|
||||
1. **VanillaObservationProcessorStep**: Converts raw robot observations into standardized format
|
||||
2. **JointVelocityProcessorStep** (optional): Adds joint velocity information to observations
|
||||
3. **MotorCurrentProcessorStep** (optional): Adds motor current readings to observations
|
||||
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
|
||||
5. **ImageCropResizeProcessorStep** (optional): Crops and resizes camera images
|
||||
6. **TimeLimitProcessorStep** (optional): Enforces episode time limits
|
||||
7. **GripperPenaltyProcessorStep** (optional): Applies penalties for inappropriate gripper usage
|
||||
8. **RewardClassifierProcessorStep** (optional): Automated reward detection using vision models
|
||||
9. **AddBatchDimensionProcessorStep**: Converts data to batch format for neural network processing
|
||||
10. **DeviceProcessorStep**: Moves data to the specified compute device (CPU/GPU)
|
||||
|
||||
#### Action Processor Pipeline
|
||||
|
||||
The action processor (`action_processor`) handles outgoing actions and human interventions:
|
||||
|
||||
1. **AddTeleopActionAsComplimentaryDataStep**: Captures teleoperator actions for logging
|
||||
2. **AddTeleopEventsAsInfoStep**: Records intervention events and episode control signals
|
||||
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
|
||||
4. **InterventionActionProcessorStep**: Handles human interventions and episode termination
|
||||
5. **Inverse Kinematics Pipeline** (when enabled):
|
||||
- **MapDeltaActionToRobotActionStep**: Converts delta actions to robot action format
|
||||
- **EEReferenceAndDelta**: Computes end-effector reference and delta movements
|
||||
- **EEBoundsAndSafety**: Enforces workspace safety bounds
|
||||
- **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
|
||||
- **GripperVelocityToJoint**: Handles gripper control commands
|
||||
|
||||
#### Configuration Examples
|
||||
|
||||
**Basic Observation Processing**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Image Processing**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {
|
||||
"observation.images.front": [180, 250, 120, 150],
|
||||
"observation.images.side": [180, 207, 180, 200]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Inverse Kinematics Setup**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"inverse_kinematics": {
|
||||
"urdf_path": "path/to/robot.urdf",
|
||||
"target_frame_name": "end_effector",
|
||||
"end_effector_bounds": {
|
||||
"min": [0.16, -0.08, 0.03],
|
||||
"max": [0.24, 0.2, 0.1]
|
||||
},
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Observation Processing
|
||||
|
||||
The HIL-SERL framework supports additional observation processing features that can improve policy learning:
|
||||
|
||||
#### Joint Velocity Processing
|
||||
|
||||
Enable joint velocity estimation to provide the policy with motion information:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This processor:
|
||||
|
||||
- Estimates joint velocities using finite differences between consecutive joint position readings
|
||||
- Adds velocity information to the observation state vector
|
||||
- Useful for policies that need motion awareness for dynamic tasks
|
||||
|
||||
#### Motor Current Processing
|
||||
|
||||
Monitor motor currents to detect contact forces and load conditions:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_current_to_observation": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This processor:
|
||||
|
||||
- Reads motor current values from the robot's control system
|
||||
- Adds current measurements to the observation state vector
|
||||
- Helps detect contact events, object weights, and mechanical resistance
|
||||
- Useful for contact-rich manipulation tasks
|
||||
|
||||
#### Combined Observation Processing
|
||||
|
||||
You can enable multiple observation processing features simultaneously:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": true,
|
||||
"add_ee_pose_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data.
|
||||
|
||||
### Finding Robot Workspace Bounds
|
||||
|
||||
Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.
|
||||
@@ -130,22 +349,56 @@ With the bounds defined, you can safely collect demonstrations for training. Tra
|
||||
|
||||
Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)):
|
||||
|
||||
1. Set `mode` to `"record"`
|
||||
2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name")
|
||||
3. Set `num_episodes` to the number of demonstrations you want to collect
|
||||
4. Set `crop_params_dict` to `null` initially (we'll determine crops later)
|
||||
5. Configure `robot`, `cameras`, and other hardware settings
|
||||
1. Set `mode` to `"record"` at the root level
|
||||
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
|
||||
3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect
|
||||
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
|
||||
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
|
||||
|
||||
Example configuration section:
|
||||
|
||||
```json
|
||||
"mode": "record",
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"dataset_root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes": 15,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "real_robot",
|
||||
"fps": 10,
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"observation": {
|
||||
"display_cameras": false
|
||||
},
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {},
|
||||
"resize_size": [128, 128]
|
||||
},
|
||||
"gripper": {
|
||||
"use_gripper": true,
|
||||
"gripper_penalty": 0.0
|
||||
},
|
||||
"reset": {
|
||||
"reset_time_s": 5.0,
|
||||
"control_time_s": 20.0
|
||||
}
|
||||
},
|
||||
"robot": {
|
||||
// ... robot configuration ...
|
||||
},
|
||||
"teleop": {
|
||||
// ... teleoperator configuration ...
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes_to_record": 15,
|
||||
"replay_episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cpu"
|
||||
}
|
||||
```
|
||||
|
||||
### Using a Teleoperation Device
|
||||
@@ -191,10 +444,20 @@ The gamepad provides a very convenient way to control the robot and the episode
|
||||
To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file.
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"teleop": {
|
||||
"type": "gamepad",
|
||||
"use_gripper": true
|
||||
"type": "gamepad",
|
||||
"use_gripper": true
|
||||
},
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
@@ -216,11 +479,21 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll
|
||||
To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file.
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"teleop": {
|
||||
"type": "so101_leader",
|
||||
"port": "/dev/tty.usbmodem585A0077921", # check your port number
|
||||
"use_degrees": true
|
||||
"type": "so101_leader",
|
||||
"port": "/dev/tty.usbmodem585A0077921",
|
||||
"use_degrees": true
|
||||
},
|
||||
"processor": {
|
||||
"control_mode": "leader",
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure.
|
||||
@@ -251,7 +524,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e
|
||||
|
||||
During recording:
|
||||
|
||||
1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions`
|
||||
1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions`
|
||||
2. Complete the task successfully
|
||||
3. The episode ends with a reward of 1 when you press the "success" button
|
||||
4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
|
||||
@@ -310,11 +583,19 @@ observation.images.front: [180, 250, 120, 150]
|
||||
Add these crop parameters to your training configuration:
|
||||
|
||||
```json
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Recommended image resolution**
|
||||
@@ -343,26 +624,52 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r
|
||||
|
||||
**Key Parameters for Data Collection**
|
||||
|
||||
- **mode**: set it to `"record"` to collect a dataset
|
||||
- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
|
||||
- **num_episodes**: Number of episodes to record
|
||||
- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected
|
||||
- **fps**: Number of frames per second to record
|
||||
- **push_to_hub**: Whether to push the dataset to the hub
|
||||
- **mode**: set it to `"record"` to collect a dataset (at root level)
|
||||
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
|
||||
- **dataset.num_episodes_to_record**: Number of episodes to record
|
||||
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
|
||||
- **env.fps**: Number of frames per second to record
|
||||
- **dataset.push_to_hub**: Whether to push the dataset to the hub
|
||||
|
||||
The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier.
|
||||
The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection.
|
||||
|
||||
**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully.
|
||||
|
||||
Example configuration section for data collection:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "real_robot",
|
||||
"fps": 10,
|
||||
"processor": {
|
||||
"reset": {
|
||||
"reset_time_s": 5.0,
|
||||
"control_time_s": 20.0,
|
||||
"terminate_on_success": false
|
||||
},
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
},
|
||||
"robot": {
|
||||
// ... robot configuration ...
|
||||
},
|
||||
"teleop": {
|
||||
// ... teleoperator configuration ...
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"task": "reward_classifier_task",
|
||||
"num_episodes_to_record": 20,
|
||||
"replay_episode": null,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"num_episodes": 20,
|
||||
"push_to_hub": true,
|
||||
"fps": 10,
|
||||
"number_of_steps_after_success": 15
|
||||
"device": "cpu"
|
||||
}
|
||||
```
|
||||
|
||||
@@ -421,9 +728,17 @@ To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
env_config = HILSerlRobotEnvConfig(
|
||||
reward_classifier_pretrained_path="path_to_your_pretrained_trained_model",
|
||||
# Other environment parameters
|
||||
config = GymManipulatorConfig(
|
||||
env=HILSerlRobotEnvConfig(
|
||||
processor=HILSerlProcessorConfig(
|
||||
reward_classifier=RewardClassifierConfig(
|
||||
pretrained_path="path_to_your_pretrained_trained_model"
|
||||
)
|
||||
),
|
||||
# Other environment parameters
|
||||
),
|
||||
dataset=DatasetConfig(...),
|
||||
mode=None # For training
|
||||
)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
@@ -432,7 +747,18 @@ or set the argument in the json config file.
|
||||
|
||||
```json
|
||||
{
|
||||
"reward_classifier_pretrained_path": "path_to_your_pretrained_model"
|
||||
"env": {
|
||||
"processor": {
|
||||
"reward_classifier": {
|
||||
"pretrained_path": "path_to_your_pretrained_model",
|
||||
"success_threshold": 0.7,
|
||||
"success_reward": 1.0
|
||||
},
|
||||
"reset": {
|
||||
"terminate_on_success": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
+56
-30
@@ -32,9 +32,12 @@ To use `gym_hil` with LeRobot, you need to create a configuration file. An examp
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "hil",
|
||||
"name": "franka_sim",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
@@ -45,28 +48,40 @@ Available tasks:
|
||||
- `PandaPickCubeGamepad-v0`: With gamepad control
|
||||
- `PandaPickCubeKeyboard-v0`: With keyboard control
|
||||
|
||||
### Gym Wrappers Configuration
|
||||
### Processor Configuration
|
||||
|
||||
```json
|
||||
"wrapper": {
|
||||
"gripper_penalty": -0.02,
|
||||
"control_time_s": 15.0,
|
||||
"use_gripper": true,
|
||||
"fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785],
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
},
|
||||
"control_mode": "gamepad"
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"gripper": {
|
||||
"use_gripper": true,
|
||||
"gripper_penalty": -0.02
|
||||
},
|
||||
"reset": {
|
||||
"control_time_s": 15.0,
|
||||
"fixed_reset_joint_positions": [
|
||||
0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785
|
||||
]
|
||||
},
|
||||
"inverse_kinematics": {
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Important parameters:
|
||||
|
||||
- `gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `use_gripper`: Whether to enable gripper control
|
||||
- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `gripper.gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `gripper.use_gripper`: Whether to enable gripper control
|
||||
- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `control_mode`: Set to `"gamepad"` to use a gamepad controller
|
||||
|
||||
## Running with HIL RL of LeRobot
|
||||
@@ -75,39 +90,50 @@ Important parameters:
|
||||
|
||||
To run the environment, set mode to null:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Recording a Dataset
|
||||
|
||||
To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0"
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/sim_dataset",
|
||||
"root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes_to_record": 10,
|
||||
"replay_episode": null,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record"
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Training a Policy
|
||||
|
||||
To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
In a different terminal, run the learner server:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.
|
||||
|
||||
|
||||
@@ -519,11 +519,14 @@ from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.policies.factory import make_processor
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
@@ -535,7 +538,7 @@ robot_config = SO100FollowerConfig(
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Initialize the policy
|
||||
policy = ACTPolicy.from_pretrained("<hf_username>/<my_policy_repo_id>")
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
@@ -544,7 +547,7 @@ dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/eval_<dataset_repo_id>",
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
@@ -559,6 +562,12 @@ _init_rerun(session_name="recording")
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
@@ -568,6 +577,8 @@ for episode_idx in range(NUM_EPISODES):
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
|
||||
+53
-5
@@ -24,11 +24,36 @@ pip install -e ".[hilserl]"
|
||||
|
||||
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json).
|
||||
|
||||
To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record".
|
||||
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
|
||||
|
||||
If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS).
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_gym",
|
||||
"root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes_to_record": 30,
|
||||
"replay_episode": null,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`.
|
||||
Key configuration points:
|
||||
|
||||
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
|
||||
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
|
||||
- Ensure `mode` is set to `"record"`
|
||||
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
|
||||
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
|
||||
|
||||
Then we can run this command to start:
|
||||
|
||||
@@ -140,9 +165,32 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
|
||||
|
||||
## Evaluate your policy in Sim
|
||||
|
||||
To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
|
||||
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
|
||||
|
||||
Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model`
|
||||
Here's an example evaluation configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_sim_dataset",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube"
|
||||
},
|
||||
"pretrained_policy_name_or_path": "your_username/il_sim_model",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Make sure to replace:
|
||||
|
||||
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
|
||||
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
|
||||
|
||||
Then you can run this command to visualize your trained policy
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
@@ -11,12 +12,14 @@ NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
policy = ACTPolicy.from_pretrained("<hf_username>/<policy_repo_id>")
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
@@ -25,7 +28,7 @@ dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<eval_dataset_repo_id>",
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
@@ -43,6 +46,12 @@ listener, events = init_keyboard_listener()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
@@ -53,6 +62,8 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
|
||||
@@ -38,7 +38,7 @@ while True:
|
||||
keyboard_keys = keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
|
||||
log_rerun_data(observation, {**arm_action, **base_action})
|
||||
log_rerun_data(observation=observation, action={**arm_action, **base_action})
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot with degrees
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build dataset action and gripper features
|
||||
action_ee_and_gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints_processor,
|
||||
initial_features={},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
|
||||
) # Get all ee action features + gripper pos action features
|
||||
|
||||
# Build dataset observation features
|
||||
obs_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=robot.observation_features,
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
) # Get all ee observation features
|
||||
|
||||
dataset_features = combine_feature_dicts(obs_ee, action_ee_and_gripper)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
dataset.push_to_hub()
|
||||
@@ -0,0 +1,215 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
action_to_transition,
|
||||
observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
ForwardKinematicsJointsToEE,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
|
||||
NUM_EPISODES = 10
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
],
|
||||
to_transition=action_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build dataset ee action features
|
||||
action_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=phone.action_features,
|
||||
use_videos=True,
|
||||
patterns=["action.ee"],
|
||||
)
|
||||
|
||||
# Get gripper pos action features
|
||||
gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints_processor,
|
||||
initial_features={},
|
||||
use_videos=True,
|
||||
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
|
||||
)
|
||||
|
||||
# Build dataset ee observation features
|
||||
observation_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=robot.observation_features,
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
)
|
||||
|
||||
dataset_features = combine_feature_dicts(action_ee, gripper, observation_ee)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
dataset.push_to_hub()
|
||||
@@ -0,0 +1,81 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
robot_ee_to_joints_processor.reset()
|
||||
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(dataset.num_frames):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
ee_action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
}
|
||||
|
||||
joint_action = robot_ee_to_joints_processor(ee_action)
|
||||
action_sent = robot.send_action(joint_action)
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
robot.disconnect()
|
||||
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specif
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot.")
|
||||
while True:
|
||||
# Get teleop observation
|
||||
phone_obs = teleop_device.get_action()
|
||||
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints_processor(phone_obs)
|
||||
|
||||
if joint_action:
|
||||
robot.send_action(joint_action)
|
||||
|
||||
time.sleep(0.01)
|
||||
+4
-2
@@ -95,7 +95,7 @@ dependencies = [
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1"]
|
||||
placo-dep = ["placo>=0.9.6"]
|
||||
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
|
||||
transformers-dep = ["transformers<=4.52.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||
|
||||
# Motors
|
||||
@@ -111,6 +111,7 @@ intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
|
||||
# stretch = [
|
||||
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
|
||||
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
@@ -152,7 +153,8 @@ all = [
|
||||
"lerobot[video_benchmark]",
|
||||
"lerobot[aloha]",
|
||||
"lerobot[pusht]",
|
||||
"lerobot[xarm]"
|
||||
"lerobot[xarm]",
|
||||
"lerobot[phone]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
@@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
@@ -24,6 +24,7 @@ class FeatureType(str, Enum):
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
LANGUAGE = "LANGUAGE"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
|
||||
@@ -21,8 +21,14 @@ OBS_ENV_STATE = "observation.environment_state"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
OBS_LANGUAGE = "observation.language"
|
||||
ACTION = "action"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
|
||||
OBS_LANGUAGE_TOKENS = "observation.language.tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask"
|
||||
|
||||
ROBOTS = "robots"
|
||||
ROBOT_TYPE = "robot_type"
|
||||
@@ -39,6 +45,9 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor"
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor"
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline
|
||||
|
||||
|
||||
def aggregate_pipeline_dataset_features(
|
||||
pipeline: DataProcessorPipeline,
|
||||
initial_features: dict[str, Any],
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates the pipeline's features and returns a features dict ready for the dataset,
|
||||
filtered to only those keys matching any of the given patterns (for action/state only).
|
||||
|
||||
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
|
||||
- `use_videos`: whether to treat image features as video streams
|
||||
- `patterns`: regexes to filter action & state features; images are included
|
||||
whenever use_videos=True, regardless of patterns.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Gather everything the pipeline features specifies, seeded with hardware cams:
|
||||
all_features = pipeline.transform_features(initial_features)
|
||||
|
||||
# Helper to decide which action/state keys survive the `patterns` filter:
|
||||
def keep(key: str) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
|
||||
# Start with hardware dict, injecting initial cameras if videos are ON:
|
||||
hw: dict[str, dict[str, Any]] = {}
|
||||
if use_videos:
|
||||
cams = {
|
||||
name: shape
|
||||
for name, shape in initial_features.items()
|
||||
if isinstance(shape, tuple) and len(shape) == 3
|
||||
}
|
||||
if cams:
|
||||
hw["observation"] = dict(cams)
|
||||
|
||||
# Go over every feature from the pipeline and merge:
|
||||
for full_key, ty in all_features.items():
|
||||
if full_key.startswith(f"{ACTION}."):
|
||||
# action.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len(f"{ACTION}.") :]
|
||||
hw.setdefault(ACTION, {})[name] = ty
|
||||
|
||||
elif full_key.startswith(f"{OBS_STATE}."):
|
||||
# observation.state.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len(f"{OBS_STATE}.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
elif full_key.startswith(f"{OBS_IMAGES}."):
|
||||
# observation.images.<cam>
|
||||
# images obey ONLY the use_videos flag, not patterns
|
||||
if not use_videos:
|
||||
continue
|
||||
name = full_key[len(f"{OBS_IMAGES}.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
else:
|
||||
# anything else (e.g. policy-only features) is ignored here
|
||||
continue
|
||||
|
||||
out: dict[str, dict] = {}
|
||||
if ACTION in hw:
|
||||
out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
|
||||
if "observation" in hw:
|
||||
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
|
||||
|
||||
return out
|
||||
@@ -470,6 +470,50 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
return policy_features
|
||||
|
||||
|
||||
def combine_feature_dicts(*dicts: dict) -> dict:
|
||||
"""
|
||||
Merge LeRobot grouped feature dicts.
|
||||
|
||||
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
|
||||
- For others (observation.images.*), last one wins (if they are identical).
|
||||
"""
|
||||
out: dict = {}
|
||||
for d in dicts:
|
||||
for key, value in d.items():
|
||||
if not isinstance(value, dict):
|
||||
out[key] = value
|
||||
continue
|
||||
|
||||
dtype = value.get("dtype")
|
||||
shape = value.get("shape")
|
||||
is_vector = (
|
||||
dtype not in ("image", "video", "string")
|
||||
and isinstance(shape, tuple)
|
||||
and len(shape) == 1
|
||||
and "names" in value
|
||||
)
|
||||
|
||||
if is_vector:
|
||||
# Initialize or retrieve the accumulating dict for this feature key
|
||||
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
|
||||
# Ensure consistent data types across merged entries
|
||||
if "dtype" in target and dtype != target["dtype"]:
|
||||
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
|
||||
|
||||
# Merge feature names: append only new ones to preserve order without duplicates
|
||||
seen = set(target["names"])
|
||||
for n in value["names"]:
|
||||
if n not in seen:
|
||||
target["names"].append(n)
|
||||
seen.add(n)
|
||||
# Recompute the shape to reflect the updated number of features
|
||||
target["shape"] = (len(target["names"]),)
|
||||
else:
|
||||
# For images/videos and non-1D entries: override with the latest definition
|
||||
out[key] = value
|
||||
return out
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
|
||||
+57
-86
@@ -161,35 +161,73 @@ class XarmEnv(EnvConfig):
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
class ImagePreprocessingConfig:
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvTransformConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
class RewardClassifierConfig:
|
||||
"""Configuration for reward classification."""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
success_threshold: float = 0.5
|
||||
success_reward: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class InverseKinematicsConfig:
|
||||
"""Configuration for inverse kinematics processing."""
|
||||
|
||||
urdf_path: str | None = None
|
||||
target_frame_name: str | None = None
|
||||
end_effector_bounds: dict[str, list[float]] | None = None
|
||||
end_effector_step_sizes: dict[str, float] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationConfig:
|
||||
"""Configuration for observation processing."""
|
||||
|
||||
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||
control_mode: str = "gamepad"
|
||||
display_cameras: bool = False
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Any | None = None
|
||||
reset_time_s: float = 5.0
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class GripperConfig:
|
||||
"""Configuration for gripper control and penalties."""
|
||||
|
||||
use_gripper: bool = True
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetConfig:
|
||||
"""Configuration for environment reset behavior."""
|
||||
|
||||
fixed_reset_joint_positions: Any | None = None
|
||||
reset_time_s: float = 5.0
|
||||
control_time_s: float = 20.0
|
||||
terminate_on_success: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILSerlProcessorConfig:
|
||||
"""Configuration for environment processing pipeline."""
|
||||
|
||||
control_mode: str = "gamepad"
|
||||
observation: ObservationConfig | None = None
|
||||
image_preprocessing: ImagePreprocessingConfig | None = None
|
||||
gripper: GripperConfig | None = None
|
||||
reset: ResetConfig | None = None
|
||||
inverse_kinematics: InverseKinematicsConfig | None = None
|
||||
reward_classifier: RewardClassifierConfig | None = None
|
||||
max_gripper_pos: float | None = 100.0
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
@@ -197,77 +235,10 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
|
||||
robot: RobotConfig | None = None
|
||||
teleop: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
fps: int = 10
|
||||
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
|
||||
|
||||
name: str = "real_robot"
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
task: str | None = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("hil")
|
||||
@dataclass
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
name: str = "PandaPickCube"
|
||||
task: str | None = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
state_dim: int = 18
|
||||
action_dim: int = 4
|
||||
fps: int = 100
|
||||
episode_length: int = 100
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
robot_config: RobotConfig | None = None
|
||||
teleop_config: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
############################
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"use_viewer": self.use_viewer,
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "hil":
|
||||
return HILEnvConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
@@ -15,6 +15,17 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0.processor_pi0 import Pi0NewLineProcessor
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"PI0Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
]
|
||||
|
||||
@@ -35,7 +35,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
@@ -51,27 +50,16 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
@@ -137,23 +125,19 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
actions = self.model(batch)[0]
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_act_pre_post_processors(
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -35,7 +35,6 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
@@ -57,7 +56,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -70,14 +68,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
@@ -106,9 +96,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -137,7 +124,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
@@ -153,11 +139,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_diffusion_pre_post_processors(
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -14,12 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
from torch import nn
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import torch
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
@@ -34,9 +39,10 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
@@ -101,6 +107,163 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for the processor config."""
|
||||
|
||||
preprocessor_config_filename: str | None
|
||||
postprocessor_config_filename: str | None
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
preprocessor_kwargs: ProcessorKwargs | None
|
||||
postprocessor_kwargs: ProcessorKwargs | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""Make a processor instance for a given policy type.
|
||||
|
||||
This function creates the appropriate processor configuration based on the policy type.
|
||||
Each policy type has its own processor with specific preprocessing steps.
|
||||
|
||||
Args:
|
||||
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
|
||||
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
|
||||
the processor from this path instead of creating a new one.
|
||||
**kwargs: Additional keyword arguments passed to the processor creation.
|
||||
|
||||
Returns:
|
||||
Tuple of (input_processor, output_processor) for the policy.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the policy type doesn't have a processor implemented.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# Extract preprocessor and postprocessor kwargs
|
||||
preprocessor_kwargs = kwargs.get("preprocessor_kwargs", {})
|
||||
postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=preprocessor_kwargs.get("to_transition"),
|
||||
to_output=preprocessor_kwargs.get("to_output"),
|
||||
),
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=postprocessor_kwargs.get("to_transition"),
|
||||
to_output=postprocessor_kwargs.get("to_output"),
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new processor based on policy type
|
||||
if isinstance(policy_cfg, TDMPCConfig):
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
|
||||
|
||||
processors = make_tdmpc_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, DiffusionConfig):
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
|
||||
|
||||
processors = make_diffusion_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, ACTConfig):
|
||||
from lerobot.policies.act.processor_act import make_act_pre_post_processors
|
||||
|
||||
processors = make_act_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VQBeTConfig):
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
|
||||
|
||||
processors = make_vqbet_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0Config):
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
|
||||
|
||||
processors = make_pi0_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0FASTConfig):
|
||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
|
||||
|
||||
processors = make_pi0fast_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
|
||||
processors = make_sac_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
||||
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
processors = make_classifier_processor(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SmolVLAConfig):
|
||||
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
|
||||
processors = make_smolvla_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
return processors
|
||||
|
||||
|
||||
def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
@@ -147,7 +310,6 @@ def make_policy(
|
||||
kwargs = {}
|
||||
if ds_meta is not None:
|
||||
features = dataset_to_policy_features(ds_meta.features)
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
else:
|
||||
if not cfg.pretrained_path:
|
||||
logging.warning(
|
||||
@@ -155,6 +317,8 @@ def make_policy(
|
||||
"rather than a dataset. Normalization modules inside the policy will have infinite values "
|
||||
"by default without stats from a dataset."
|
||||
)
|
||||
if env_cfg is None:
|
||||
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
@@ -171,7 +335,7 @@ def make_policy(
|
||||
policy = policy_cls(**kwargs)
|
||||
|
||||
policy.to(cfg.device)
|
||||
assert isinstance(policy, nn.Module)
|
||||
assert isinstance(policy, torch.nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
|
||||
@@ -1,420 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def create_stats_buffers(
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||
statistics.
|
||||
|
||||
Args: (see Normalize and Unnormalize)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
|
||||
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
assert isinstance(norm_mode, NormalizationMode)
|
||||
|
||||
shape = tuple(ft.shape)
|
||||
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model."
|
||||
)
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# TODO: Remove this shallow copy
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
|
||||
"""
|
||||
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
|
||||
original range used by the environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
module: nn.Module,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> None:
|
||||
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||
|
||||
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
|
||||
but is factored out so it can be reused by both classes and stay in sync.
|
||||
"""
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
shape: tuple[int, ...] = tuple(ft.shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# reduce spatial dimensions, keep channel dimension only
|
||||
c, *_ = shape
|
||||
shape = (c, 1, 1)
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||
mean_data = stats[key]["mean"]
|
||||
std_data = stats[key]["std"]
|
||||
if isinstance(mean_data, torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
mean = mean_data.clone().to(dtype=torch.float32)
|
||||
std = std_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_mean", mean)
|
||||
module.register_buffer(f"{prefix}_std", std)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||
min_data = stats[key]["min"]
|
||||
max_data = stats[key]["max"]
|
||||
if isinstance(min_data, torch.Tensor):
|
||||
min_val = min_data.clone().to(dtype=torch.float32)
|
||||
max_val = max_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_min", min_val)
|
||||
module.register_buffer(f"{prefix}_max", max_val)
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
|
||||
class NormalizeBuffer(nn.Module):
|
||||
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizeBuffer(nn.Module):
|
||||
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max_val - min_val) + min_val
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
@@ -56,18 +56,15 @@ from collections import deque
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertConfig,
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
from lerobot.utils.utils import get_safe_dtype, init_logging
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
@@ -223,28 +220,17 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
@@ -253,99 +239,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
|
||||
"""
|
||||
Transform state dict keys to match expected model structure.
|
||||
|
||||
Transformations:
|
||||
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
|
||||
model.paligemma_with_expert.paligemma.lm_head
|
||||
- model.paligemma_with_expert.paligemma.language_model.model ->
|
||||
model.paligemma_with_expert.paligemma.model.language_model
|
||||
- model.paligemma_with_expert.paligemma.vision_tower ->
|
||||
model.paligemma_with_expert.paligemma.model.vision_tower
|
||||
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
|
||||
model.paligemma_with_expert.paligemma.model.multi_modal_projector
|
||||
|
||||
Also handles tied weights between lm_head.weight and
|
||||
embed_tokens.weight.
|
||||
"""
|
||||
import re
|
||||
|
||||
transformed_dict = {}
|
||||
|
||||
transformations = [
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
|
||||
".paligemma_with_expert.paligemma.lm_head",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
|
||||
".paligemma_with_expert.paligemma.model.language_model",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
|
||||
".paligemma_with_expert.paligemma.model.vision_tower",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
|
||||
".paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
),
|
||||
]
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
for pattern, replacement in transformations:
|
||||
new_key = pattern.sub(replacement, new_key)
|
||||
transformed_dict[new_key] = value
|
||||
|
||||
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
|
||||
lm_head_key = None
|
||||
embed_tokens_key = None
|
||||
|
||||
for key in transformed_dict:
|
||||
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
|
||||
lm_head_key = key
|
||||
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
|
||||
embed_tokens_key = key
|
||||
if lm_head_key and embed_tokens_key:
|
||||
break
|
||||
|
||||
if lm_head_key and not embed_tokens_key:
|
||||
embed_tokens_key = lm_head_key.replace(
|
||||
".lm_head.weight", ".model.language_model.embed_tokens.weight"
|
||||
)
|
||||
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
|
||||
elif embed_tokens_key and not lm_head_key:
|
||||
lm_head_key = embed_tokens_key.replace(
|
||||
".model.language_model.embed_tokens.weight", ".lm_head.weight"
|
||||
)
|
||||
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
|
||||
|
||||
return transformed_dict
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(
|
||||
cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
|
||||
) -> "PI0Policy":
|
||||
"""Override to apply key transformations before loading."""
|
||||
from safetensors.torch import load_file
|
||||
|
||||
init_logging()
|
||||
# Load the state dict from file safely
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
|
||||
# Apply key transformations
|
||||
transformed_state_dict = cls._transform_state_dict_keys(state_dict)
|
||||
|
||||
# Load the transformed state dict
|
||||
msg = model.load_state_dict(transformed_state_dict, strict=strict)
|
||||
|
||||
# Log message
|
||||
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
|
||||
return model
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@@ -377,14 +270,13 @@ class PI0Policy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
@@ -394,8 +286,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -410,12 +300,10 @@ class PI0Policy(PreTrainedPolicy):
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("action_is_pad")
|
||||
|
||||
@@ -482,26 +370,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_STATE].device
|
||||
tasks = batch["task"]
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
@@ -567,7 +435,7 @@ class PI0FlowMatching(nn.Module):
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: PI0Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
|
||||
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one.
|
||||
This is required for the PaliGemma tokenizer.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def make_pi0_pre_post_processors(
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -58,7 +58,6 @@ from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
@@ -146,14 +145,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FAST(config)
|
||||
|
||||
@@ -221,8 +212,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
@@ -235,8 +224,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -249,8 +236,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss_dict = self.model.forward(batch)
|
||||
return loss_dict["loss"], loss_dict
|
||||
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_pi0fast_pre_post_processors(
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -28,7 +28,6 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
from lerobot.policies.normalize import NormalizeBuffer
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
@@ -45,7 +44,6 @@ class SACPolicy(
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -53,7 +51,6 @@ class SACPolicy(
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
self._init_normalization(dataset_stats)
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
@@ -88,8 +85,7 @@ class SACPolicy(
|
||||
|
||||
observations_features = None
|
||||
if self.shared_encoder and self.actor.encoder.has_images:
|
||||
# Cache and normalize image features
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
@@ -391,28 +387,12 @@ class SACPolicy(
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_normalization(self, dataset_stats):
|
||||
"""Initialize input/output normalization modules."""
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = nn.Identity()
|
||||
if self.config.dataset_stats is not None:
|
||||
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
|
||||
self.normalize_inputs = NormalizeBuffer(
|
||||
self.config.input_features, self.config.normalization_mapping, params
|
||||
)
|
||||
stats = dataset_stats or params
|
||||
self.normalize_targets = NormalizeBuffer(
|
||||
self.config.output_features, self.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic
|
||||
if self.shared_encoder
|
||||
else SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
@@ -424,9 +404,7 @@ class SACPolicy(
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
@@ -434,9 +412,7 @@ class SACPolicy(
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
@@ -490,10 +466,9 @@ class SACPolicy(
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self._init_image_layers()
|
||||
self._init_state_layers()
|
||||
self._compute_output_dim()
|
||||
@@ -568,11 +543,10 @@ class SACObservationEncoder(nn.Module):
|
||||
def forward(
|
||||
self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
|
||||
) -> Tensor:
|
||||
obs = self.input_normalization(obs)
|
||||
parts = []
|
||||
if self.has_images:
|
||||
if cache is None:
|
||||
cache = self.get_cached_image_features(obs, normalize=False)
|
||||
cache = self.get_cached_image_features(obs)
|
||||
parts.append(self._encode_images(cache, detach))
|
||||
if self.has_env:
|
||||
parts.append(self.env_encoder(obs["observation.environment_state"]))
|
||||
@@ -585,7 +559,7 @@ class SACObservationEncoder(nn.Module):
|
||||
"No parts to concatenate, you should have at least one image or environment state or state"
|
||||
)
|
||||
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Extract and optionally cache image features from observations.
|
||||
|
||||
This function processes image observations through the vision encoder once and returns
|
||||
@@ -597,26 +571,17 @@ class SACObservationEncoder(nn.Module):
|
||||
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
|
||||
- Caching these features can provide 2-4x speedup in training and inference
|
||||
|
||||
Normalization behavior:
|
||||
- When called from inside forward(): set normalize=False since inputs are already normalized
|
||||
- When called from outside forward(): set normalize=True to ensure proper input normalization
|
||||
|
||||
Usage patterns:
|
||||
- Called in select_action() with normalize=True
|
||||
- Called in select_action()
|
||||
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
|
||||
- Called internally by forward() with normalize=False
|
||||
- Called internally by forward()
|
||||
|
||||
Args:
|
||||
obs: Dictionary of observation tensors containing image keys
|
||||
normalize: Whether to normalize observations before encoding
|
||||
Set to True when calling directly from outside the encoder's forward method
|
||||
Set to False when calling from within forward() where inputs are already normalized
|
||||
|
||||
Returns:
|
||||
Dictionary mapping image keys to their corresponding encoded features
|
||||
"""
|
||||
if normalize:
|
||||
obs = self.input_normalization(obs)
|
||||
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
|
||||
out = self.image_encoder(batched)
|
||||
chunks = torch.chunk(out, len(self.image_keys), dim=0)
|
||||
@@ -747,7 +712,6 @@ class CriticEnsemble(nn.Module):
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
output_normalization (nn.Module): normalization layer for actions.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
@@ -757,13 +721,11 @@ class CriticEnsemble(nn.Module):
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
output_normalization: nn.Module,
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
@@ -775,11 +737,6 @@ class CriticEnsemble(nn.Module):
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
# NOTE: We normalize actions it helps for sample efficiency
|
||||
actions: dict[str, torch.tensor] = {"action": actions}
|
||||
# NOTE: Normalization layer took dict in input and outputs a dict that why
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_sac_pre_post_processors(
|
||||
config: SACConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -20,7 +20,6 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import OBS_IMAGE, REWARD
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
@@ -108,22 +107,12 @@ class Classifier(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize normalization (standardized with the policy framework)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
@@ -247,10 +236,6 @@ class Classifier(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Normalize inputs if needed
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
IdentityProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
)
|
||||
|
||||
|
||||
def make_classifier_processor(
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
NormalizerProcessorStep(
|
||||
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
NormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [DeviceProcessorStep(device="cpu"), IdentityProcessorStep()]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name="classifier_preprocessor",
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name="classifier_postprocessor",
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -53,21 +53,13 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import (
|
||||
Normalize,
|
||||
Unnormalize,
|
||||
)
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
@@ -76,102 +68,6 @@ from lerobot.policies.utils import (
|
||||
)
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||
|
||||
|
||||
def canonicalise(k: str) -> str:
|
||||
"""
|
||||
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
|
||||
normalisation-buffer key.
|
||||
"""
|
||||
return _VARIANT_RE.sub(".buffer_", k)
|
||||
|
||||
|
||||
def standardise_state_dict(
|
||||
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
|
||||
) -> tuple[dict[str, torch.Tensor], list[str]]:
|
||||
"""
|
||||
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
|
||||
• If several variant keys collapse to the same canonical name we keep the
|
||||
first one and log the collision.
|
||||
• Returns the new dict + a list of entries that could not be matched.
|
||||
"""
|
||||
out, collisions, unmatched = {}, {}, []
|
||||
|
||||
for k, v in checkpoint.items():
|
||||
canon = canonicalise(k)
|
||||
if canon in ref_keys:
|
||||
if canon in out: # duplicate after collapsing
|
||||
collisions.setdefault(canon, []).append(k)
|
||||
else:
|
||||
out[canon] = v
|
||||
else:
|
||||
unmatched.append(k)
|
||||
|
||||
if verbose:
|
||||
for canon, variants in collisions.items():
|
||||
print(f"[standardise_state_dict] '{canon}' ← {variants}")
|
||||
if unmatched:
|
||||
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
|
||||
|
||||
out.update({k: checkpoint[k] for k in unmatched})
|
||||
return out, unmatched
|
||||
|
||||
|
||||
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
||||
"""
|
||||
Renames keys in a checkpoint dictionary based on the given rename string.
|
||||
|
||||
Args:
|
||||
checkpoint (dict): The checkpoint dictionary.
|
||||
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
|
||||
|
||||
Returns:
|
||||
dict: The modified checkpoint with renamed keys.
|
||||
"""
|
||||
|
||||
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
|
||||
|
||||
new_checkpoint = {}
|
||||
for k, v in checkpoint.items():
|
||||
for old_key, new_key in rename_dict.items():
|
||||
if old_key in k:
|
||||
k = k.replace(old_key, new_key)
|
||||
new_checkpoint[k] = v
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def load_smolvla(
|
||||
model: torch.nn.Module,
|
||||
filename: str | os.PathLike,
|
||||
*,
|
||||
device: str = "cpu",
|
||||
checkpoint_keys_mapping: str = "",
|
||||
) -> torch.nn.Module:
|
||||
state_dict = safetensors.torch.load_file(filename, device=device)
|
||||
|
||||
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
|
||||
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
|
||||
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
|
||||
|
||||
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
|
||||
|
||||
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
|
||||
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
|
||||
|
||||
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
|
||||
raise RuntimeError(
|
||||
"SmolVLA %d missing / %d unexpected keys",
|
||||
len(missing),
|
||||
len(unexpected),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
@@ -326,28 +222,17 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.reset()
|
||||
|
||||
@@ -357,23 +242,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
|
||||
@classmethod
|
||||
def _load_as_safetensor(
|
||||
cls,
|
||||
model: "SmolVLAPolicy",
|
||||
model_file: str,
|
||||
map_location: str,
|
||||
strict: bool,
|
||||
):
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
return load_smolvla(
|
||||
model,
|
||||
model_file,
|
||||
device=map_location,
|
||||
checkpoint_keys_mapping="model._orig_mod.//model.",
|
||||
)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@@ -389,7 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||
|
||||
@@ -397,8 +266,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -408,8 +275,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -450,11 +315,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
@@ -518,30 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_STATE].device
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
if len(tasks) == 1:
|
||||
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
||||
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding=self.config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
ProcessorStepRegistry,
|
||||
RenameProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_smolvla_pre_post_processors(
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
SmolVLANewLineProcessor(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.vlm_model_name,
|
||||
padding=config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=config.tokenizer_max_length,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
|
||||
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one."""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -36,7 +36,6 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
@@ -63,26 +62,19 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = TDMPCTOLD(config)
|
||||
self.model_target = deepcopy(self.model)
|
||||
for param in self.model_target.parameters():
|
||||
@@ -137,7 +129,6 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
actions = torch.clamp(actions, -1, +1)
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -147,11 +138,12 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -320,11 +312,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_tdmpc_pre_post_processors(
|
||||
config: TDMPCConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -28,7 +28,6 @@ import torchvision
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
@@ -48,7 +47,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: VQBeTConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -61,14 +59,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self.reset()
|
||||
@@ -128,7 +118,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -142,10 +131,12 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -165,10 +156,8 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
# loss: total loss of training RVQ
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
||||
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_vqbet_pre_post_processors(
|
||||
config: VQBeTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -14,41 +14,90 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .device_processor import DeviceProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from .observation_processor import VanillaObservationProcessor
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
merge_transitions,
|
||||
transition_to_batch,
|
||||
transition_to_dataset_frame,
|
||||
)
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .gym_action_processor import Numpy2TorchActionProcessorStep, Torch2NumpyActionProcessorStep
|
||||
from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
TimeLimitProcessorStep,
|
||||
)
|
||||
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessorStep
|
||||
from .pipeline import (
|
||||
ActionProcessor,
|
||||
DoneProcessor,
|
||||
EnvTransition,
|
||||
IdentityProcessor,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ActionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DoneProcessorStep,
|
||||
IdentityProcessorStep,
|
||||
InfoProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RewardProcessor,
|
||||
RobotProcessor,
|
||||
TransitionKey,
|
||||
TruncatedProcessor,
|
||||
RewardProcessorStep,
|
||||
RobotProcessorPipeline,
|
||||
TruncatedProcessorStep,
|
||||
)
|
||||
from .rename_processor import RenameProcessor
|
||||
from .rename_processor import RenameProcessorStep
|
||||
from .tokenizer_processor import TokenizerProcessorStep
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessor",
|
||||
"DeviceProcessor",
|
||||
"DoneProcessor",
|
||||
"ActionProcessorStep",
|
||||
"AddTeleopActionAsComplimentaryDataStep",
|
||||
"AddTeleopEventsAsInfoStep",
|
||||
"ComplementaryDataProcessorStep",
|
||||
"batch_to_transition",
|
||||
"create_transition",
|
||||
"DeviceProcessorStep",
|
||||
"DoneProcessorStep",
|
||||
"EnvTransition",
|
||||
"IdentityProcessor",
|
||||
"InfoProcessor",
|
||||
"NormalizerProcessor",
|
||||
"UnnormalizerProcessor",
|
||||
"ObservationProcessor",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"IdentityProcessorStep",
|
||||
"ImageCropResizeProcessorStep",
|
||||
"InfoProcessorStep",
|
||||
"InterventionActionProcessorStep",
|
||||
"JointVelocityProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"merge_transitions",
|
||||
"MotorCurrentProcessorStep",
|
||||
"NormalizerProcessorStep",
|
||||
"Numpy2TorchActionProcessorStep",
|
||||
"ObservationProcessorStep",
|
||||
"PolicyProcessorPipeline",
|
||||
"ProcessorKwargs",
|
||||
"ProcessorStep",
|
||||
"ProcessorStepRegistry",
|
||||
"RenameProcessor",
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"RenameProcessorStep",
|
||||
"RewardClassifierProcessorStep",
|
||||
"RewardProcessorStep",
|
||||
"DataProcessorPipeline",
|
||||
"TimeLimitProcessorStep",
|
||||
"AddBatchDimensionProcessorStep",
|
||||
"RobotProcessorPipeline",
|
||||
"TokenizerProcessorStep",
|
||||
"Torch2NumpyActionProcessorStep",
|
||||
"transition_to_batch",
|
||||
"transition_to_dataset_frame",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
"TruncatedProcessorStep",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .core import EnvTransition
|
||||
from .pipeline import (
|
||||
ActionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_action")
|
||||
class AddBatchDimensionActionStep(ActionProcessorStep):
|
||||
"""Process action component in-place, adding batch dimension if needed."""
|
||||
|
||||
def action(self, action):
|
||||
if not isinstance(action, Tensor) or action.dim() != 1:
|
||||
return action
|
||||
|
||||
return action.unsqueeze(0)
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
|
||||
class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
||||
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||
|
||||
def observation(self, observation):
|
||||
# Process state observations - add batch dim if 1D
|
||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||
if state_key in observation:
|
||||
state_value = observation[state_key]
|
||||
if isinstance(state_value, Tensor) and state_value.dim() == 1:
|
||||
observation[state_key] = state_value.unsqueeze(0)
|
||||
|
||||
# Process single image observation - add batch dim if 3D
|
||||
if OBS_IMAGE in observation:
|
||||
image_value = observation[OBS_IMAGE]
|
||||
if isinstance(image_value, Tensor) and image_value.dim() == 3:
|
||||
observation[OBS_IMAGE] = image_value.unsqueeze(0)
|
||||
|
||||
# Process multiple image observations - add batch dim if 3D
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
return observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
|
||||
class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
"""Process complementary data in-place, handling task field batching."""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
# Process task field - wrap string in list to add batch dimension
|
||||
if "task" in complementary_data:
|
||||
task_value = complementary_data["task"]
|
||||
if isinstance(task_value, str):
|
||||
complementary_data["task"] = [task_value]
|
||||
|
||||
# Process index field - add batch dim if 0D
|
||||
if "index" in complementary_data:
|
||||
index_value = complementary_data["index"]
|
||||
if isinstance(index_value, Tensor) and index_value.dim() == 0:
|
||||
complementary_data["index"] = index_value.unsqueeze(0)
|
||||
|
||||
# Process task_index field - add batch dim if 0D
|
||||
if "task_index" in complementary_data:
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
return complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||
class AddBatchDimensionProcessorStep(ProcessorStep):
|
||||
"""Processor that adds batch dimensions to observations and actions when needed.
|
||||
|
||||
This processor ensures that observations and actions have proper batch dimensions for model processing:
|
||||
|
||||
- For state observations (observation.state, observation.environment_state):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
|
||||
- For image observations (observation.image, observation.images.*):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
|
||||
|
||||
- For actions:
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
|
||||
- For task field in complementary data:
|
||||
Wraps string task in a list to add batch dimension
|
||||
(task must be a string or list of strings)
|
||||
|
||||
This is useful when processing single transitions that need to be batched for
|
||||
model inference or when converting from unbatched environment outputs to
|
||||
batched model inputs.
|
||||
|
||||
The processor only modifies tensors that need batching and leaves already
|
||||
batched tensors unchanged.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# State: (7,) -> (1, 7)
|
||||
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
||||
# Action: (4,) -> (1, 4)
|
||||
# Task: "pick_cube" -> ["pick_cube"]
|
||||
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
||||
```
|
||||
"""
|
||||
|
||||
to_batch_action_processor: AddBatchDimensionActionStep = field(
|
||||
default_factory=AddBatchDimensionActionStep
|
||||
)
|
||||
to_batch_observation_processor: AddBatchDimensionObservationStep = field(
|
||||
default_factory=AddBatchDimensionObservationStep
|
||||
)
|
||||
to_batch_complementary_data_processor: AddBatchDimensionComplementaryDataStep = field(
|
||||
default_factory=AddBatchDimensionComplementaryDataStep
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = self.to_batch_action_processor(transition)
|
||||
transition = self.to_batch_observation_processor(transition)
|
||||
transition = self.to_batch_complementary_data_processor(transition)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# NOTE: We ignore the batch dimension when transforming features
|
||||
return features
|
||||
@@ -0,0 +1,481 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
@singledispatch
|
||||
def to_tensor(
|
||||
value: Any,
|
||||
*,
|
||||
dtype: torch.dtype | None = torch.float32,
|
||||
device: torch.device | str | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert various data types to PyTorch tensors with configurable options.
|
||||
|
||||
This is a unified tensor conversion function using single dispatch to handle
|
||||
different input types appropriately.
|
||||
|
||||
Args:
|
||||
value: Input value to convert (tensor, array, scalar, sequence, etc.)
|
||||
dtype: Target tensor dtype. If None, preserves original dtype.
|
||||
device: Target device for the tensor.
|
||||
|
||||
Returns:
|
||||
PyTorch tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input type is not supported.
|
||||
"""
|
||||
raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")
|
||||
|
||||
|
||||
@to_tensor.register(torch.Tensor)
|
||||
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle existing PyTorch tensors."""
|
||||
if dtype is not None:
|
||||
value = value.to(dtype=dtype)
|
||||
if device is not None:
|
||||
value = value.to(device=device)
|
||||
return value
|
||||
|
||||
|
||||
@to_tensor.register(np.ndarray)
|
||||
def _(
|
||||
value: np.ndarray,
|
||||
*,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Handle numpy arrays."""
|
||||
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
|
||||
if value.ndim == 0:
|
||||
# Numpy scalars should be converted to 0-dimensional tensors
|
||||
scalar_value = value.item()
|
||||
return torch.tensor(scalar_value, dtype=dtype, device=device)
|
||||
|
||||
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
|
||||
tensor = torch.from_numpy(value)
|
||||
|
||||
# Apply dtype conversion if specified
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
if device is not None:
|
||||
tensor = tensor.to(device=device)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@to_tensor.register(int)
|
||||
@to_tensor.register(float)
|
||||
@to_tensor.register(np.integer)
|
||||
@to_tensor.register(np.floating)
|
||||
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle scalar values including numpy scalars."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(list)
|
||||
@to_tensor.register(tuple)
|
||||
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle sequences (lists, tuples)."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(dict)
|
||||
def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
"""Handle dictionaries by recursively converting values to tensors."""
|
||||
if not value:
|
||||
return {}
|
||||
|
||||
result = {}
|
||||
for key, sub_value in value.items():
|
||||
if sub_value is None:
|
||||
continue
|
||||
|
||||
if isinstance(sub_value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
**kwargs,
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert individual values to tensors
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
|
||||
"""Convert tensor to numpy/scalar if needed."""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
|
||||
return x
|
||||
|
||||
|
||||
def _is_image(arr: Any) -> bool:
|
||||
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
|
||||
|
||||
|
||||
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
state, images = {}, {}
|
||||
for k, v in obs.items():
|
||||
if "image" in k.lower() or _is_image(v):
|
||||
images[k] = v
|
||||
else:
|
||||
state[k] = v
|
||||
return state, images
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Private Helper Functions (Common Logic)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract complementary data (pad flags, task, index, task_index)."""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
|
||||
|
||||
def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition:
|
||||
"""Merge two transitions, with other taking precedence."""
|
||||
out = deepcopy(base)
|
||||
|
||||
for key in (
|
||||
TransitionKey.OBSERVATION,
|
||||
TransitionKey.ACTION,
|
||||
TransitionKey.INFO,
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
):
|
||||
if other.get(key):
|
||||
out.setdefault(key, {}).update(deepcopy(other[key]))
|
||||
|
||||
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
|
||||
if k in other:
|
||||
out[k] = other[k]
|
||||
return out
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Core Conversion Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation: dict[str, Any] | None = None,
|
||||
action: dict[str, Any] | None = None,
|
||||
reward: float = 0.0,
|
||||
done: bool = False,
|
||||
truncated: bool = False,
|
||||
info: dict[str, Any] | None = None,
|
||||
complementary_data: dict[str, Any] | None = None,
|
||||
) -> EnvTransition:
|
||||
"""Create an EnvTransition with sensible defaults.
|
||||
|
||||
Args:
|
||||
observation: Observation dictionary.
|
||||
action: Action dictionary.
|
||||
reward: Scalar reward value.
|
||||
done: Episode termination flag.
|
||||
truncated: Episode truncation flag.
|
||||
info: Additional info dictionary.
|
||||
complementary_data: Complementary data dictionary.
|
||||
|
||||
Returns:
|
||||
Complete EnvTransition dictionary.
|
||||
"""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info if info is not None else {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
|
||||
}
|
||||
|
||||
|
||||
def action_to_transition(action: dict[str, Any]) -> EnvTransition: # action_to_transition
|
||||
"""
|
||||
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
|
||||
"""
|
||||
act_dict: dict[str, Any] = {}
|
||||
for k, v in action.items():
|
||||
# Check if the value is a type that should not be converted to a tensor.
|
||||
if isinstance(v, (Rotation, dict)):
|
||||
act_dict[f"{ACTION}.{k}"] = v
|
||||
continue
|
||||
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
|
||||
|
||||
return create_transition(observation={}, action=act_dict)
|
||||
|
||||
|
||||
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
|
||||
def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
|
||||
"""
|
||||
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
|
||||
"""
|
||||
state, images = _split_obs_to_state_and_images(observation)
|
||||
|
||||
obs_dict: dict[str, Any] = {}
|
||||
for k, v in state.items():
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
|
||||
|
||||
for cam, img in images.items():
|
||||
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
|
||||
|
||||
return create_transition(observation=obs_dict, action={})
|
||||
|
||||
|
||||
def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]:
|
||||
"""
|
||||
Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions.
|
||||
"""
|
||||
out: dict[str, Any] = {}
|
||||
action_dict = transition.get(TransitionKey.ACTION) or {}
|
||||
|
||||
if action_dict is None:
|
||||
return out
|
||||
|
||||
for k, v in action_dict.items():
|
||||
if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
|
||||
out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
|
||||
out[out_key] = float(v)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition:
|
||||
"""Merge multiple transitions or return single transition.
|
||||
|
||||
Args:
|
||||
transitions: Either a single transition or iterable of transitions.
|
||||
|
||||
Returns:
|
||||
Merged EnvTransition.
|
||||
"""
|
||||
|
||||
if not isinstance(transitions, Sequence): # Single transition
|
||||
return transitions
|
||||
|
||||
items = list(transitions)
|
||||
if not items:
|
||||
raise ValueError("merge_transitions() requires a non-empty sequence of transitions")
|
||||
|
||||
result = items[0]
|
||||
for t in items[1:]:
|
||||
result = _merge_transitions(result, t)
|
||||
return result
|
||||
|
||||
|
||||
def transition_to_dataset_frame(
|
||||
transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict]
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a single EnvTransition or an iterable of them into a flat, dataset-friendly dictionary for training or evaluation.
|
||||
|
||||
Processes transitions according to the provided feature specification and returns
|
||||
data in the format expected by machine learning models and datasets.
|
||||
|
||||
Args:
|
||||
transitions_or_transition: Either a single EnvTransition dict or an iterable of them
|
||||
(which will be merged using merge_transitions).
|
||||
features: Feature specification dictionary with the following structure:
|
||||
- 'action': dict with 'names': list of action feature names
|
||||
- 'observation.state': dict with 'names': list of state feature names
|
||||
- keys starting with 'observation.images.' are passed through as-is
|
||||
|
||||
Returns:
|
||||
Flat dictionary containing:
|
||||
- numpy arrays for "observation.state" and "action" (vectorized from feature names)
|
||||
- any image tensors defined in features (passed through unchanged)
|
||||
- next.{reward,done,truncated} scalar values
|
||||
- info dict
|
||||
- *_is_pad flags and task from complementary_data
|
||||
"""
|
||||
action_names = features.get(ACTION, {}).get("names", [])
|
||||
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
|
||||
image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
|
||||
|
||||
tr = merge_transitions(transitions_or_transition)
|
||||
obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
act = tr.get(TransitionKey.ACTION, {}) or {}
|
||||
batch: dict[str, Any] = {}
|
||||
|
||||
# Images passthrough
|
||||
for k in image_keys:
|
||||
if k in obs:
|
||||
batch[k] = obs[k]
|
||||
|
||||
# Observation.state vector
|
||||
if obs_state_names:
|
||||
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
|
||||
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Action vector
|
||||
if action_names:
|
||||
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
|
||||
batch[ACTION] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Add transition metadata
|
||||
if tr.get(TransitionKey.REWARD) is not None:
|
||||
reward_val = _from_tensor(tr[TransitionKey.REWARD])
|
||||
# Check if features expect array format, otherwise keep as scalar
|
||||
if REWARD in features and features[REWARD].get("shape") == (1,):
|
||||
batch[REWARD] = np.array([reward_val], dtype=np.float32)
|
||||
else:
|
||||
batch[REWARD] = reward_val
|
||||
|
||||
if tr.get(TransitionKey.DONE) is not None:
|
||||
done_val = _from_tensor(tr[TransitionKey.DONE])
|
||||
if DONE in features and features[DONE].get("shape") == (1,):
|
||||
batch[DONE] = np.array([done_val], dtype=bool)
|
||||
else:
|
||||
batch[DONE] = done_val
|
||||
|
||||
if tr.get(TransitionKey.TRUNCATED) is not None:
|
||||
truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED])
|
||||
if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
|
||||
batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
|
||||
else:
|
||||
batch[TRUNCATED] = truncated_val
|
||||
|
||||
# Complementary data flags and task
|
||||
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if comp:
|
||||
# pad flags
|
||||
for k, v in comp.items():
|
||||
if k.endswith("_is_pad"):
|
||||
batch[k] = v
|
||||
# task label
|
||||
if comp.get("task") is not None:
|
||||
batch["task"] = comp["task"]
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
"""Convert a batch dict coming from LeRobot replay/dataset code into an EnvTransition dictionary.
|
||||
|
||||
The function maps well known keys to the EnvTransition structure. Missing keys are
|
||||
filled with sane defaults (None or 0.0/False).
|
||||
|
||||
Keys recognised (case-sensitive):
|
||||
* "observation.*" (keys starting with "observation." are grouped into observation dict)
|
||||
* "action"
|
||||
* "next.reward"
|
||||
* "next.done"
|
||||
* "next.truncated"
|
||||
* "info"
|
||||
* "_is_pad" patterns (padding flags)
|
||||
* "task", "index", "task_index" (complementary data)
|
||||
|
||||
Additional keys are ignored so that existing dataloaders can carry extra
|
||||
metadata without breaking the processor.
|
||||
|
||||
Args:
|
||||
batch: Batch dictionary from datasets or dataloaders containing the above keys.
|
||||
|
||||
Returns:
|
||||
EnvTransition dictionary with properly structured transition data.
|
||||
"""
|
||||
|
||||
# Validate input type
|
||||
if not isinstance(batch, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||
|
||||
# Extract observation keys
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
complementary_data = _extract_complementary_data(batch)
|
||||
|
||||
return create_transition(
|
||||
observation=observation_keys if observation_keys else None,
|
||||
action=batch.get("action"),
|
||||
reward=batch.get("next.reward", 0.0),
|
||||
done=batch.get("next.done", False),
|
||||
truncated=batch.get("next.truncated", False),
|
||||
info=batch.get("info", {}),
|
||||
complementary_data=complementary_data if complementary_data else None,
|
||||
)
|
||||
|
||||
|
||||
def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
"""Inverse of batch_to_transition. Returns a dict with canonical field names used throughout LeRobot.
|
||||
|
||||
Converts an EnvTransition back to the batch format expected by datasets, dataloaders,
|
||||
and other LeRobot components.
|
||||
|
||||
Output format:
|
||||
* "action": Action data from transition
|
||||
* "next.reward": Reward value (defaults to 0.0)
|
||||
* "next.done": Done flag (defaults to False)
|
||||
* "next.truncated": Truncated flag (defaults to False)
|
||||
* "info": Info dictionary (defaults to {})
|
||||
* Flattened observation keys (e.g., "observation.state", "observation.images.cam1")
|
||||
* Complementary data fields ("task", "index", "task_index", padding flags)
|
||||
|
||||
Args:
|
||||
transition: EnvTransition dictionary to convert.
|
||||
|
||||
Returns:
|
||||
Batch dictionary with canonical LeRobot field names suitable for dataloaders.
|
||||
"""
|
||||
batch = {
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Add complementary data
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if comp_data:
|
||||
batch.update(comp_data)
|
||||
|
||||
# Flatten observation dict
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if isinstance(observation, dict):
|
||||
batch.update(observation)
|
||||
|
||||
return batch
|
||||
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
|
||||
# TODO(Steven): Use consts
|
||||
OBSERVATION = "observation"
|
||||
ACTION = "action"
|
||||
REWARD = "reward"
|
||||
DONE = "done"
|
||||
TRUNCATED = "truncated"
|
||||
INFO = "info"
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
EnvTransition = TypedDict(
|
||||
"EnvTransition",
|
||||
{
|
||||
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
||||
TransitionKey.ACTION.value: Any | torch.Tensor | None,
|
||||
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||||
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||||
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
||||
TransitionKey.INFO.value: dict[str, Any] | None,
|
||||
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,145 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||
@dataclass
|
||||
class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
"""
|
||||
Map a tensor to a delta action dictionary.
|
||||
"""
|
||||
|
||||
use_gripper: bool = True
|
||||
|
||||
def action(self, action: Tensor) -> dict:
|
||||
if action.dim() > 1:
|
||||
action = action.squeeze(0)
|
||||
|
||||
# TODO (maractingi): add rotation
|
||||
delta_action = {
|
||||
f"{ACTION}.delta_x": action[0],
|
||||
f"{ACTION}.delta_y": action[1],
|
||||
f"{ACTION}.delta_z": action[2],
|
||||
}
|
||||
if self.use_gripper:
|
||||
delta_action[f"{ACTION}.gripper"] = action[3]
|
||||
return delta_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{ACTION}.delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
if self.use_gripper:
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
||||
@dataclass
|
||||
class MapDeltaActionToRobotActionStep(ActionProcessorStep):
|
||||
"""
|
||||
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
|
||||
for use with inverse kinematics processors.
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.delta_x": float,
|
||||
"action.delta_y": float,
|
||||
"action.delta_z": float,
|
||||
"action.gripper": float (optional),
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.target_x": float,
|
||||
"action.target_y": float,
|
||||
"action.target_z": float,
|
||||
"action.target_wx": float,
|
||||
"action.target_wy": float,
|
||||
"action.target_wz": float,
|
||||
"action.gripper": float,
|
||||
}
|
||||
"""
|
||||
|
||||
# Scale factors for delta movements
|
||||
position_scale: float = 1.0
|
||||
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
|
||||
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
|
||||
|
||||
def action(self, action: dict) -> dict:
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||
delta_x = action.pop(f"{ACTION}.delta_x", 0.0)
|
||||
delta_y = action.pop(f"{ACTION}.delta_y", 0.0)
|
||||
delta_z = action.pop(f"{ACTION}.delta_z", 0.0)
|
||||
gripper = action.pop(f"{ACTION}.gripper", 1.0) # Default to "stay" (1.0)
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position
|
||||
enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise
|
||||
|
||||
# Scale the deltas appropriately
|
||||
scaled_delta_x = delta_x * self.position_scale
|
||||
scaled_delta_y = delta_y * self.position_scale
|
||||
scaled_delta_z = delta_z * self.position_scale
|
||||
|
||||
# For gamepad/keyboard, we don't have rotation input, so set to 0
|
||||
# These could be extended in the future for more sophisticated teleoperators
|
||||
target_wx = 0.0
|
||||
target_wy = 0.0
|
||||
target_wz = 0.0
|
||||
|
||||
# Update action with robot target format
|
||||
action = {
|
||||
f"{ACTION}.enabled": enabled,
|
||||
f"{ACTION}.target_x": scaled_delta_x,
|
||||
f"{ACTION}.target_y": scaled_delta_y,
|
||||
f"{ACTION}.target_z": scaled_delta_z,
|
||||
f"{ACTION}.target_wx": target_wx,
|
||||
f"{ACTION}.target_wy": target_wy,
|
||||
f"{ACTION}.target_wz": target_wz,
|
||||
f"{ACTION}.gripper": float(gripper),
|
||||
}
|
||||
|
||||
return action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transform features to match output format."""
|
||||
features.pop(f"{ACTION}.delta_x", None)
|
||||
features.pop(f"{ACTION}.delta_y", None)
|
||||
features.pop(f"{ACTION}.delta_z", None)
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
|
||||
features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
@@ -19,64 +19,116 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("device_processor")
|
||||
@dataclass
|
||||
class DeviceProcessor:
|
||||
"""Processes transitions by moving tensors to the specified device.
|
||||
class DeviceProcessorStep(ProcessorStep):
|
||||
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
|
||||
|
||||
This processor ensures that all tensors in the transition are moved to the
|
||||
specified device (CPU or GPU) before they are returned.
|
||||
specified device (CPU or GPU) before they are returned. It can also convert
|
||||
floating-point tensors to a specified dtype while preserving non-float types
|
||||
(int, long, bool, etc.).
|
||||
"""
|
||||
|
||||
device: torch.device = "cpu"
|
||||
device: str = "cpu"
|
||||
float_dtype: str | None = None
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"half": torch.float16,
|
||||
"float": torch.float32,
|
||||
"double": torch.float64,
|
||||
}
|
||||
|
||||
def __post_init__(self):
|
||||
self.device = get_safe_torch_device(self.device)
|
||||
self.tensor_device: torch.device = get_safe_torch_device(self.device)
|
||||
self.device = self.tensor_device.type # cuda might have changed to cuda:1
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
# Validate and convert float_dtype string to torch dtype
|
||||
if self.float_dtype is not None:
|
||||
if self.float_dtype not in self.DTYPE_MAPPING:
|
||||
raise ValueError(
|
||||
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
|
||||
)
|
||||
|
||||
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
|
||||
else:
|
||||
self._target_float_dtype = None
|
||||
|
||||
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Process a tensor by moving to device and optionally converting float dtype.
|
||||
|
||||
If the tensor is already on a GPU and we're configured for a GPU, it preserves
|
||||
that GPU placement (useful for multi-GPU training with Accelerate).
|
||||
Otherwise, it moves to the configured device.
|
||||
"""
|
||||
# Determine target device
|
||||
if tensor.is_cuda and self.tensor_device.type == "cuda":
|
||||
# Both tensor and target are on GPU - preserve tensor's GPU placement
|
||||
# This handles multi-GPU scenarios where Accelerate has already placed
|
||||
# tensors on the correct GPU for each process
|
||||
target_device = tensor.device
|
||||
else:
|
||||
# Either tensor is on CPU, or we're configured for CPU
|
||||
# In both cases, use the configured device
|
||||
target_device = self.tensor_device
|
||||
|
||||
# Only move if necessary
|
||||
if tensor.device != target_device:
|
||||
tensor = tensor.to(target_device, non_blocking=self.non_blocking)
|
||||
|
||||
# Convert float dtype if specified and tensor is floating point
|
||||
if self._target_float_dtype is not None and tensor.is_floating_point():
|
||||
tensor = tensor.to(dtype=self._target_float_dtype)
|
||||
|
||||
return tensor
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Create a copy of the transition
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Process observation tensors
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_observation = {
|
||||
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in observation.items()
|
||||
}
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
simple_tensor_keys = [
|
||||
TransitionKey.ACTION,
|
||||
TransitionKey.REWARD,
|
||||
TransitionKey.DONE,
|
||||
TransitionKey.TRUNCATED,
|
||||
]
|
||||
|
||||
# Process action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
|
||||
dict_tensor_keys = [
|
||||
TransitionKey.OBSERVATION,
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
]
|
||||
|
||||
# Process reward tensor
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
if reward is not None and isinstance(reward, torch.Tensor):
|
||||
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
|
||||
# Process simple tensors
|
||||
for key in simple_tensor_keys:
|
||||
value = transition.get(key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
new_transition[key] = self._process_tensor(value)
|
||||
|
||||
# Process done tensor
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
if done is not None and isinstance(done, torch.Tensor):
|
||||
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process truncated tensor
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is not None and isinstance(truncated, torch.Tensor):
|
||||
new_transition[TransitionKey.TRUNCATED] = truncated.to(
|
||||
self.device, non_blocking=self.non_blocking
|
||||
)
|
||||
# Process dictionary-like tensors
|
||||
for key in dict_tensor_keys:
|
||||
data_dict = transition.get(key)
|
||||
if data_dict is not None:
|
||||
new_data_dict = {
|
||||
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in data_dict.items()
|
||||
}
|
||||
new_transition[key] = new_data_dict
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {"device": self.device}
|
||||
return {"device": self.device, "float_dtype": self.float_dtype}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
#! /usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
|
||||
from .converters import to_tensor
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("torch2numpy_action_processor")
|
||||
@dataclass
|
||||
class Torch2NumpyActionProcessorStep(ActionProcessorStep):
|
||||
"""Convert PyTorch tensor actions to NumPy arrays."""
|
||||
|
||||
squeeze_batch_dim: bool = True
|
||||
|
||||
def action(self, action: torch.Tensor) -> np.ndarray:
|
||||
if not isinstance(action, torch.Tensor):
|
||||
raise TypeError(
|
||||
f"Expected torch.Tensor or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
|
||||
numpy_action = action.detach().cpu().numpy()
|
||||
|
||||
# Remove batch dimensions but preserve action dimensions
|
||||
# Only squeeze if there's a batch dimension (first dim == 1)
|
||||
if (
|
||||
self.squeeze_batch_dim
|
||||
and numpy_action.shape
|
||||
and len(numpy_action.shape) > 1
|
||||
and numpy_action.shape[0] == 1
|
||||
):
|
||||
numpy_action = numpy_action.squeeze(0)
|
||||
|
||||
return numpy_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
||||
@dataclass
|
||||
class Numpy2TorchActionProcessorStep(ActionProcessorStep):
|
||||
"""Convert NumPy array action to PyTorch tensor."""
|
||||
|
||||
def action(self, action: np.ndarray) -> torch.Tensor:
|
||||
if not isinstance(action, np.ndarray):
|
||||
raise TypeError(
|
||||
f"Expected np.ndarray or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
return torch_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -0,0 +1,382 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
InfoProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TruncatedProcessorStep,
|
||||
)
|
||||
|
||||
GRIPPER_KEY = "gripper"
|
||||
DISCRETE_PENALTY_KEY = "discrete_penalty"
|
||||
TELEOP_ACTION_KEY = "teleop_action"
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasTeleopEvents(Protocol):
|
||||
"""Minimal protocol for objects that provide teleoperation events.
|
||||
|
||||
This protocol only defines the additional get_teleop_events() method,
|
||||
avoiding duplication of the entire Teleoperator interface.
|
||||
"""
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""Get extra control events from the teleoperator.
|
||||
|
||||
Returns:
|
||||
Dictionary containing control events such as:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Type variable constrained to Teleoperator subclasses that also implement events
|
||||
TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
"""Runtime check that a teleoperator implements get_teleop_events."""
|
||||
if not isinstance(teleop, HasTeleopEvents):
|
||||
raise TypeError(
|
||||
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
|
||||
f"Compatible teleoperators: GamepadTeleop, KeyboardEndEffectorTeleop"
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||
@dataclass
|
||||
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
"""Add teleoperator action to transition complementary data."""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||
@dataclass
|
||||
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
|
||||
"""Add teleoperator control events to transition info.
|
||||
|
||||
This processor step extracts control events from teleoperators that support
|
||||
event-based interaction (intervention detection, episode termination, etc.).
|
||||
|
||||
Works with any teleoperator that inherits from Teleoperator and implements the
|
||||
get_teleop_events() method, including custom user-defined teleoperators.
|
||||
|
||||
Built-in compatible teleoperators:
|
||||
- GamepadTeleop: Uses gamepad buttons for control events
|
||||
- KeyboardEndEffectorTeleop: Uses keyboard keys for control events
|
||||
"""
|
||||
|
||||
teleop_device: TeleopWithEvents
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that the teleoperator supports events."""
|
||||
_check_teleop_with_events(self.teleop_device)
|
||||
|
||||
def info(self, info: dict) -> dict:
|
||||
new_info = dict(info)
|
||||
|
||||
teleop_events = self.teleop_device.get_teleop_events()
|
||||
new_info.update(teleop_events)
|
||||
return new_info
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||
@dataclass
|
||||
class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
"""Crop and resize image observations."""
|
||||
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
if self.resize_size is None and not self.crop_params_dict:
|
||||
return observation
|
||||
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Process all image keys in the observation
|
||||
for key in observation:
|
||||
if "image" not in key:
|
||||
continue
|
||||
|
||||
image = observation[key]
|
||||
device = image.device
|
||||
# NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu
|
||||
if device.type == "mps":
|
||||
image = image.cpu()
|
||||
# Crop if crop params are provided for this key
|
||||
if self.crop_params_dict is not None and key in self.crop_params_dict:
|
||||
crop_params = self.crop_params_dict[key]
|
||||
image = F.crop(image, *crop_params)
|
||||
if self.resize_size is not None:
|
||||
image = F.resize(image, self.resize_size)
|
||||
image = image.clamp(0.0, 1.0)
|
||||
new_observation[key] = image.to(device)
|
||||
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"crop_params_dict": self.crop_params_dict,
|
||||
"resize_size": self.resize_size,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if self.resize_size is None:
|
||||
return features
|
||||
for key in features:
|
||||
if "image" in key:
|
||||
features[key] = PolicyFeature(type=features[key].type, shape=self.resize_size)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("time_limit_processor")
|
||||
class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
"""Track episode steps and enforce time limits."""
|
||||
|
||||
max_episode_steps: int
|
||||
current_step: int = 0
|
||||
|
||||
def truncated(self, truncated):
|
||||
self.current_step += 1
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
truncated = True
|
||||
# TODO (steven): missing an else truncated = False?
|
||||
return truncated
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"max_episode_steps": self.max_episode_steps,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
self.current_step = 0
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""Apply penalty for inappropriate gripper usage."""
|
||||
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
"""Calculate gripper penalty and add to complementary data."""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return complementary_data
|
||||
|
||||
gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
|
||||
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
||||
|
||||
# Normalize gripper state and action
|
||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||
|
||||
# Calculate penalty boolean as in original
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
||||
)
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Create new complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the processor state."""
|
||||
self.last_gripper_state = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||
class InterventionActionProcessorStep(ProcessorStep):
|
||||
"""Handle human intervention actions and episode termination."""
|
||||
|
||||
use_gripper: bool = False
|
||||
terminate_on_success: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return transition
|
||||
|
||||
# Get intervention signals from complementary data
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
|
||||
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
|
||||
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
|
||||
success = info.get(TeleopEvents.SUCCESS, False)
|
||||
rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False)
|
||||
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Override action if intervention is active
|
||||
if is_intervention and teleop_action is not None:
|
||||
if isinstance(teleop_action, dict):
|
||||
# Convert teleop_action dict to tensor format
|
||||
action_list = [
|
||||
teleop_action.get(f"{ACTION}.delta_x", 0.0),
|
||||
teleop_action.get(f"{ACTION}.delta_y", 0.0),
|
||||
teleop_action.get(f"{ACTION}.delta_z", 0.0),
|
||||
]
|
||||
if self.use_gripper:
|
||||
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
|
||||
elif isinstance(teleop_action, np.ndarray):
|
||||
action_list = teleop_action.tolist()
|
||||
else:
|
||||
action_list = teleop_action
|
||||
|
||||
teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
|
||||
new_transition[TransitionKey.ACTION] = teleop_action_tensor
|
||||
|
||||
# Handle episode termination
|
||||
new_transition[TransitionKey.DONE] = bool(terminate_episode) or (
|
||||
self.terminate_on_success and success
|
||||
)
|
||||
new_transition[TransitionKey.REWARD] = float(success)
|
||||
|
||||
# Update info with intervention metadata
|
||||
info = new_transition.get(TransitionKey.INFO, {})
|
||||
info[TeleopEvents.IS_INTERVENTION] = is_intervention
|
||||
info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode
|
||||
info[TeleopEvents.SUCCESS] = success
|
||||
new_transition[TransitionKey.INFO] = info
|
||||
|
||||
# Update complementary data with teleop action
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"use_gripper": self.use_gripper,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("reward_classifier_processor")
|
||||
class RewardClassifierProcessorStep(ProcessorStep):
|
||||
"""Apply reward classification to image observations."""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
device: str = "cpu"
|
||||
success_threshold: float = 0.5
|
||||
success_reward: float = 1.0
|
||||
terminate_on_success: bool = True
|
||||
|
||||
reward_classifier: Any = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the reward classifier after dataclass initialization."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
|
||||
self.reward_classifier.to(self.device)
|
||||
self.reward_classifier.eval()
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None or self.reward_classifier is None:
|
||||
return new_transition
|
||||
|
||||
# Extract images from observation
|
||||
images = {key: value for key, value in observation.items() if "image" in key}
|
||||
|
||||
if not images:
|
||||
return new_transition
|
||||
|
||||
# Run reward classifier
|
||||
start_time = time.perf_counter()
|
||||
with torch.inference_mode():
|
||||
success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold)
|
||||
|
||||
classifier_frequency = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
# Calculate reward and termination
|
||||
reward = new_transition.get(TransitionKey.REWARD, 0.0)
|
||||
terminated = new_transition.get(TransitionKey.DONE, False)
|
||||
|
||||
if math.isclose(success, 1, abs_tol=1e-2):
|
||||
reward = self.success_reward
|
||||
if self.terminate_on_success:
|
||||
terminated = True
|
||||
|
||||
# Update transition
|
||||
new_transition[TransitionKey.REWARD] = reward
|
||||
new_transition[TransitionKey.DONE] = terminated
|
||||
|
||||
# Update info with classifier frequency
|
||||
info = new_transition.get(TransitionKey.INFO, {})
|
||||
info["reward_classifier_frequency"] = classifier_frequency
|
||||
new_transition[TransitionKey.INFO] = info
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"device": self.device,
|
||||
"success_threshold": self.success_threshold,
|
||||
"success_reward": self.success_reward,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -0,0 +1,109 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_STATE
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||
class JointVelocityProcessorStep(ObservationProcessorStep):
|
||||
"""Add joint velocity information to observations."""
|
||||
|
||||
dt: float = 0.1
|
||||
|
||||
last_joint_positions: torch.Tensor | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
# Get current joint positions (assuming they're in observation.state)
|
||||
current_positions = observation.get(OBS_STATE)
|
||||
if current_positions is None:
|
||||
# TODO(steven): if we get here, then the transform_features method will not hold
|
||||
raise ValueError(f"{OBS_STATE} is not in observation")
|
||||
|
||||
# Initialize last joint positions if not already set
|
||||
if self.last_joint_positions is None:
|
||||
self.last_joint_positions = current_positions.clone()
|
||||
joint_velocities = torch.zeros_like(current_positions)
|
||||
else:
|
||||
# Compute velocities
|
||||
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
|
||||
|
||||
self.last_joint_positions = current_positions.clone()
|
||||
|
||||
# Extend observation with velocities
|
||||
extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation[OBS_STATE] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"dt": self.dt,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_joint_positions = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if OBS_STATE in features:
|
||||
original_feature = features[OBS_STATE]
|
||||
# Double the shape to account for positions + velocities
|
||||
new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]
|
||||
|
||||
features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("current_processor")
|
||||
class MotorCurrentProcessorStep(ObservationProcessorStep):
|
||||
"""Add motor current information to observations."""
|
||||
|
||||
robot: Robot | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
# Get current values from robot state
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot is not set")
|
||||
|
||||
present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
|
||||
motor_currents = torch.tensor(
|
||||
[present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
|
||||
dtype=torch.float32,
|
||||
).unsqueeze(0)
|
||||
|
||||
current_state = observation.get(OBS_STATE)
|
||||
if current_state is None:
|
||||
return observation
|
||||
|
||||
extended_state = torch.cat([current_state, motor_currents], dim=-1)
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation[OBS_STATE] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if OBS_STATE in features and self.robot is not None:
|
||||
original_feature = features[OBS_STATE]
|
||||
# Add motor current dimensions to the original state shape
|
||||
num_motors = 0
|
||||
if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
|
||||
num_motors = len(self.robot.bus.motors) # type: ignore[attr-defined]
|
||||
|
||||
if num_motors > 0:
|
||||
new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
|
||||
features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
return features
|
||||
@@ -0,0 +1,503 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
|
||||
|
||||
This script:
|
||||
1. Loads an existing pretrained policy model
|
||||
2. Extracts normalization statistics from the model
|
||||
3. Creates both preprocessor and postprocessor:
|
||||
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
|
||||
- Postprocessor: unnormalizes outputs (actions) for inference
|
||||
4. Removes normalization layers from the model state_dict
|
||||
5. Saves the new model and both processors
|
||||
|
||||
Usage:
|
||||
python src/lerobot/processor/migrate_policy_normalization.py \
|
||||
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--policy-type act \
|
||||
--push-to-hub
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from .pipeline import PolicyProcessorPipeline
|
||||
from .rename_processor import RenameProcessorStep
|
||||
|
||||
# Policy type to class mapping
|
||||
POLICY_CLASSES = {
|
||||
"act": "lerobot.policies.act.modeling_act.ACTPolicy",
|
||||
"diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
|
||||
"pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
|
||||
"pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
|
||||
"smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
|
||||
"tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
|
||||
"vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
|
||||
"sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
|
||||
"classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
|
||||
}
|
||||
|
||||
|
||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Extract normalization statistics from model state_dict."""
|
||||
stats = {}
|
||||
|
||||
# Define patterns to match and their prefixes to remove
|
||||
normalization_patterns = [
|
||||
"normalize_inputs.buffer_",
|
||||
"unnormalize_outputs.buffer_",
|
||||
"normalize_targets.buffer_",
|
||||
"normalize.", # Must come after normalize_* patterns
|
||||
"unnormalize.", # Must come after unnormalize_* patterns
|
||||
"input_normalizer.",
|
||||
"output_normalizer.",
|
||||
]
|
||||
|
||||
# Process each key in state_dict
|
||||
for key, tensor in state_dict.items():
|
||||
# Try each pattern
|
||||
for pattern in normalization_patterns:
|
||||
if key.startswith(pattern):
|
||||
# Extract the remaining part after the pattern
|
||||
remaining = key[len(pattern) :]
|
||||
parts = remaining.split(".")
|
||||
|
||||
# Need at least feature name and stat type
|
||||
if len(parts) >= 2:
|
||||
# Last part is the stat type (mean, std, min, max, etc.)
|
||||
stat_type = parts[-1]
|
||||
# Everything else is the feature name
|
||||
feature_name = ".".join(parts[:-1]).replace("_", ".")
|
||||
|
||||
# Add to stats
|
||||
if feature_name not in stats:
|
||||
stats[feature_name] = {}
|
||||
stats[feature_name][stat_type] = tensor.clone()
|
||||
|
||||
# Only process the first matching pattern
|
||||
break
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def detect_features_and_norm_modes(
|
||||
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
|
||||
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
|
||||
"""Detect features and normalization modes from config and stats."""
|
||||
features = {}
|
||||
norm_modes = {}
|
||||
|
||||
# First, check if there's a normalization_mapping in the config
|
||||
if "normalization_mapping" in config:
|
||||
print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
|
||||
# Extract normalization modes from config
|
||||
for feature_name, mode_str in config["normalization_mapping"].items():
|
||||
# Convert string to NormalizationMode enum
|
||||
if mode_str == "mean_std":
|
||||
mode = NormalizationMode.MEAN_STD
|
||||
elif mode_str == "min_max":
|
||||
mode = NormalizationMode.MIN_MAX
|
||||
else:
|
||||
print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
|
||||
continue
|
||||
|
||||
# Determine feature type from feature name
|
||||
if "image" in feature_name or "visual" in feature_name:
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in feature_name:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in feature_name:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
|
||||
norm_modes[feature_type] = mode
|
||||
|
||||
# Try to extract from config
|
||||
if "features" in config:
|
||||
for key, feature_config in config["features"].items():
|
||||
shape = feature_config.get("shape", feature_config.get("dim"))
|
||||
shape = (shape,) if isinstance(shape, int) else tuple(shape)
|
||||
|
||||
# Determine feature type
|
||||
if "image" in key or "visual" in key:
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE # Default
|
||||
|
||||
features[key] = PolicyFeature(feature_type, shape)
|
||||
|
||||
# If no features in config, infer from stats
|
||||
if not features:
|
||||
for key, stat_dict in stats.items():
|
||||
# Get shape from any stat tensor
|
||||
tensor = next(iter(stat_dict.values()))
|
||||
shape = tuple(tensor.shape)
|
||||
|
||||
# Determine feature type based on key
|
||||
if "image" in key or "visual" in key or "pixels" in key:
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key or "joint" in key or "position" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
|
||||
features[key] = PolicyFeature(feature_type, shape)
|
||||
|
||||
# If normalization modes weren't in config, determine based on available stats
|
||||
if not norm_modes:
|
||||
for key, stat_dict in stats.items():
|
||||
if key in features:
|
||||
if "mean" in stat_dict and "std" in stat_dict:
|
||||
feature_type = features[key].type
|
||||
if feature_type not in norm_modes:
|
||||
norm_modes[feature_type] = NormalizationMode.MEAN_STD
|
||||
elif "min" in stat_dict and "max" in stat_dict:
|
||||
feature_type = features[key].type
|
||||
if feature_type not in norm_modes:
|
||||
norm_modes[feature_type] = NormalizationMode.MIN_MAX
|
||||
|
||||
# Default normalization modes if not detected
|
||||
if FeatureType.VISUAL not in norm_modes:
|
||||
norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD
|
||||
if FeatureType.STATE not in norm_modes:
|
||||
norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX
|
||||
if FeatureType.ACTION not in norm_modes:
|
||||
norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD
|
||||
|
||||
return features, norm_modes
|
||||
|
||||
|
||||
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Remove normalization layers from state_dict."""
|
||||
new_state_dict = {}
|
||||
|
||||
# Patterns to remove
|
||||
remove_patterns = [
|
||||
"normalize_inputs.",
|
||||
"unnormalize_outputs.",
|
||||
"normalize_targets.", # Added pattern for target normalization
|
||||
"normalize.",
|
||||
"unnormalize.",
|
||||
"input_normalizer.",
|
||||
"output_normalizer.",
|
||||
"normalizer.",
|
||||
]
|
||||
|
||||
for key, tensor in state_dict.items():
|
||||
should_remove = any(pattern in key for pattern in remove_patterns)
|
||||
if not should_remove:
|
||||
new_state_dict[key] = tensor
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""Convert features from old format to PolicyFeature objects."""
|
||||
converted_features = {}
|
||||
|
||||
for key, feature_dict in features_dict.items():
|
||||
# Determine feature type based on key
|
||||
if "image" in key or "visual" in key:
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
|
||||
# Get shape from feature dict
|
||||
shape = feature_dict.get("shape", feature_dict.get("dim"))
|
||||
shape = (shape,) if isinstance(shape, int) else tuple(shape)
|
||||
|
||||
converted_features[key] = PolicyFeature(feature_type, shape)
|
||||
|
||||
return converted_features
|
||||
|
||||
|
||||
def load_model_from_hub(
|
||||
repo_id: str, revision: str = None
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
||||
"""Load model state_dict and config from hub."""
|
||||
# Download files
|
||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||
|
||||
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
||||
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
|
||||
|
||||
# Load state_dict
|
||||
state_dict = load_safetensors(safetensors_path)
|
||||
|
||||
# Load config
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
|
||||
return state_dict, config, train_config
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Migrate policy models with normalization layers to new pipeline system"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model (hub repo or local directory)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output directory for migrated model (default: same as pretrained-path)",
|
||||
)
|
||||
parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub")
|
||||
parser.add_argument(
|
||||
"--hub-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hub repository ID for pushing (default: same as pretrained-path)",
|
||||
)
|
||||
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
|
||||
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load model and config
|
||||
print(f"Loading model from {args.pretrained_path}...")
|
||||
if os.path.isdir(args.pretrained_path):
|
||||
# Local directory
|
||||
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
|
||||
with open(os.path.join(args.pretrained_path, "config.json")) as f:
|
||||
config = json.load(f)
|
||||
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
|
||||
train_config = json.load(f)
|
||||
else:
|
||||
# Hub repository
|
||||
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
|
||||
|
||||
# Extract normalization statistics
|
||||
print("Extracting normalization statistics...")
|
||||
stats = extract_normalization_stats(state_dict)
|
||||
|
||||
print(f"Found normalization statistics for: {list(stats.keys())}")
|
||||
|
||||
# Detect input features and normalization modes
|
||||
print("Detecting features and normalization modes...")
|
||||
features, norm_map = detect_features_and_norm_modes(config, stats)
|
||||
|
||||
print(f"Detected features: {list(features.keys())}")
|
||||
print(f"Normalization modes: {norm_map}")
|
||||
|
||||
# Remove normalization layers from state_dict
|
||||
print("Removing normalization layers from model...")
|
||||
new_state_dict = remove_normalization_layers(state_dict)
|
||||
|
||||
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
|
||||
if removed_keys:
|
||||
print(f"Removed {len(removed_keys)} normalization layer keys")
|
||||
|
||||
# Determine output path
|
||||
if args.output_dir:
|
||||
output_dir = Path(args.output_dir)
|
||||
else:
|
||||
if os.path.isdir(args.pretrained_path):
|
||||
output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated"
|
||||
else:
|
||||
output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Clean up config - remove normalization_mapping field
|
||||
cleaned_config = dict(config)
|
||||
if "normalization_mapping" in cleaned_config:
|
||||
print("Removing 'normalization_mapping' field from config")
|
||||
del cleaned_config["normalization_mapping"]
|
||||
policy_type = deepcopy(cleaned_config["type"])
|
||||
|
||||
del cleaned_config["type"]
|
||||
|
||||
# Instantiate the policy model with cleaned config and load the cleaned state dict
|
||||
print(f"Instantiating {policy_type} policy model...")
|
||||
policy_class_path = POLICY_CLASSES[policy_type]
|
||||
module_path, class_name = policy_class_path.rsplit(".", 1)
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
policy_class = getattr(module, class_name)
|
||||
|
||||
# Create config class instance
|
||||
config_module_path = module_path.replace("modeling", "configuration")
|
||||
config_module = importlib.import_module(config_module_path)
|
||||
# Handle special cases for config class names
|
||||
config_class_names = {
|
||||
"act": "ACTConfig",
|
||||
"diffusion": "DiffusionConfig",
|
||||
"pi0": "PI0Config",
|
||||
"pi0fast": "PI0FASTConfig",
|
||||
"smolvla": "SmolVLAConfig",
|
||||
"tdmpc": "TDMPCConfig",
|
||||
"vqbet": "VQBeTConfig",
|
||||
"sac": "SACConfig",
|
||||
"classifier": "ClassifierConfig",
|
||||
}
|
||||
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
|
||||
config_class = getattr(config_module, config_class_name)
|
||||
|
||||
# Convert input_features and output_features to PolicyFeature objects - these are mandatory
|
||||
if "input_features" not in cleaned_config:
|
||||
raise ValueError("Missing mandatory 'input_features' in config")
|
||||
if "output_features" not in cleaned_config:
|
||||
raise ValueError("Missing mandatory 'output_features' in config")
|
||||
|
||||
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
|
||||
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
|
||||
|
||||
# Create config instance from cleaned config dict
|
||||
policy_config = config_class(**cleaned_config)
|
||||
|
||||
# Create policy instance - some policies expect dataset_stats
|
||||
policy = policy_class(policy_config)
|
||||
|
||||
# Load the cleaned state dict
|
||||
policy.load_state_dict(new_state_dict, strict=True)
|
||||
print("Successfully loaded cleaned state dict into policy model")
|
||||
|
||||
# Now create preprocessor and postprocessor with cleaned_config available
|
||||
print("Creating preprocessor and postprocessor...")
|
||||
# The pattern from existing processor factories:
|
||||
# - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
|
||||
# - Postprocessor has one UnnormalizerProcessorStep for output_features only
|
||||
|
||||
# Get features from cleaned_config (now they're PolicyFeature objects)
|
||||
input_features = cleaned_config.get("input_features", {})
|
||||
output_features = cleaned_config.get("output_features", {})
|
||||
|
||||
# Create preprocessor with two normalizers (following the pattern from processor factories)
|
||||
preprocessor_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**input_features, **output_features},
|
||||
norm_map=norm_map,
|
||||
stats=stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=policy_config.device),
|
||||
]
|
||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")
|
||||
|
||||
# Create postprocessor with unnormalizer for outputs only
|
||||
postprocessor_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
|
||||
]
|
||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")
|
||||
|
||||
# Determine hub repo ID if pushing to hub
|
||||
if args.push_to_hub:
|
||||
if args.hub_repo_id:
|
||||
hub_repo_id = args.hub_repo_id
|
||||
else:
|
||||
if not os.path.isdir(args.pretrained_path):
|
||||
# Use same repo with "_migrated" suffix
|
||||
hub_repo_id = f"{args.pretrained_path}_migrated"
|
||||
else:
|
||||
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
|
||||
else:
|
||||
hub_repo_id = None
|
||||
|
||||
# Save preprocessor and postprocessor to root directory
|
||||
print(f"Saving preprocessor to {output_dir}...")
|
||||
preprocessor.save_pretrained(output_dir)
|
||||
if args.push_to_hub:
|
||||
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||
|
||||
print(f"Saving postprocessor to {output_dir}...")
|
||||
postprocessor.save_pretrained(output_dir)
|
||||
if args.push_to_hub:
|
||||
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||
|
||||
# Save model using the policy's save_pretrained method
|
||||
print(f"Saving model to {output_dir}...")
|
||||
policy.save_pretrained(
|
||||
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
|
||||
)
|
||||
|
||||
# Generate and save model card
|
||||
print("Generating model card...")
|
||||
# Get metadata from original config
|
||||
dataset_repo_id = train_config.get("repo_id", "unknown")
|
||||
license = config.get("license", "apache-2.0")
|
||||
|
||||
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
|
||||
tags = set(tags).union({"robotics", "lerobot", policy_type})
|
||||
tags = list(tags)
|
||||
|
||||
# Generate model card
|
||||
card = policy.generate_model_card(
|
||||
dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
|
||||
)
|
||||
|
||||
# Save model card locally
|
||||
card.save(str(output_dir / "README.md"))
|
||||
print(f"Model card saved to {output_dir / 'README.md'}")
|
||||
# Push model card to hub if requested
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_dir / "README.md"),
|
||||
path_in_repo="README.md",
|
||||
repo_id=hub_repo_id,
|
||||
repo_type="model",
|
||||
commit_message="Add model card for migrated model",
|
||||
)
|
||||
print("Model card pushed to hub")
|
||||
|
||||
print("\nMigration complete!")
|
||||
print(f"Migrated model saved to: {output_dir}")
|
||||
if args.push_to_hub:
|
||||
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,179 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
"""Convert numpy arrays and other types to torch tensors."""
|
||||
tensor_stats: dict[str, dict[str, Tensor]] = {}
|
||||
for key, sub in stats.items():
|
||||
tensor_stats[key] = {}
|
||||
for stat_name, value in sub.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
tensor_val = torch.from_numpy(value.astype(np.float32))
|
||||
elif isinstance(value, torch.Tensor):
|
||||
tensor_val = value.to(dtype=torch.float32)
|
||||
elif isinstance(value, (int, float, list, tuple)):
|
||||
tensor_val = torch.tensor(value, dtype=torch.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
|
||||
tensor_stats[key][stat_name] = tensor_val
|
||||
return tensor_stats
|
||||
from .converters import to_tensor
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||
class NormalizerProcessor:
|
||||
"""Normalizes observations and actions in a single processor step.
|
||||
class _NormalizationMixin:
|
||||
"""
|
||||
A mixin class providing core functionality for normalization and unnormalization.
|
||||
|
||||
This processor handles normalization of both observation and action tensors
|
||||
using either mean/std normalization or min/max scaling to a [-1, 1] range.
|
||||
|
||||
For each tensor key in the stats dictionary, the processor will:
|
||||
- Use mean/std normalization if those statistics are provided: (x - mean) / std
|
||||
- Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1
|
||||
|
||||
The processor can be configured to normalize only specific keys by setting
|
||||
the normalize_keys parameter.
|
||||
This class manages normalization statistics, their conversion to tensors, device placement,
|
||||
and the application of normalization transformations. It is designed to be inherited by
|
||||
concrete ProcessorStep implementations.
|
||||
"""
|
||||
|
||||
# Features and normalisation map are mandatory to match the design of normalize.py
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
|
||||
# Pre-computed statistics coming from dataset.meta.stats for instance.
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
# Explicit subset of keys to normalise. If ``None`` every key (except
|
||||
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
|
||||
# membership checks O(1).
|
||||
normalize_keys: set[str] | None = None
|
||||
|
||||
device: torch.device | str | None = None
|
||||
eps: float = 1e-8
|
||||
normalize_observation_keys: set[str] | None = None
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
normalize_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
) -> NormalizerProcessor:
|
||||
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
|
||||
|
||||
The features and norm_map parameters are mandatory to match the design
|
||||
pattern used in normalize.py.
|
||||
"""
|
||||
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats,
|
||||
normalize_keys=normalize_keys,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle deserialization from JSON config
|
||||
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||
reconstructed_features = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed_features[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed_features
|
||||
# Robust JSON deserialization handling (guard empty maps)
|
||||
if self.features:
|
||||
first_val = next(iter(self.features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
reconstructed = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed
|
||||
|
||||
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||
reconstructed_norm_map = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed_norm_map
|
||||
if self.norm_map:
|
||||
# if keys are strings (JSON), rebuild enum map
|
||||
if all(isinstance(k, str) for k in self.norm_map.keys()):
|
||||
reconstructed = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed
|
||||
|
||||
# Convert statistics once so we avoid repeated numpy→Tensor conversions
|
||||
# during runtime.
|
||||
# Convert stats to tensors and move to the target device once during initialization.
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
|
||||
# Ensure *normalize_keys* is a set for fast look-ups and compare by
|
||||
# value later when returning the configuration.
|
||||
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
|
||||
self.normalize_keys = set(self.normalize_keys)
|
||||
def to(self, device: torch.device | str) -> _NormalizationMixin:
|
||||
"""Moves the processor's normalization stats to the specified device and returns self."""
|
||||
self.device = device
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
return self
|
||||
|
||||
def _normalize_obs(self, observation):
|
||||
if observation is None:
|
||||
return None
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat: dict[str, Tensor] = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU
|
||||
return flat
|
||||
|
||||
# Decide which keys should be normalised for this call.
|
||||
if self.normalize_keys is not None:
|
||||
keys_to_norm = self.normalize_keys
|
||||
else:
|
||||
# Use feature map to skip action keys.
|
||||
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
|
||||
|
||||
processed = dict(observation)
|
||||
for key in keys_to_norm:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
continue
|
||||
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
if isinstance(orig_val, torch.Tensor)
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
def load_state_dict(self, state: dict[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
# Load to the processor's configured device.
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = (tensor - mean) / (std + self.eps)
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
return processed
|
||||
|
||||
def _normalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
return action
|
||||
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return (tensor - mean) / (std + self.eps)
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._normalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with normalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
config = {
|
||||
@@ -183,45 +88,87 @@ class NormalizerProcessor:
|
||||
},
|
||||
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||
}
|
||||
if self.normalize_keys is not None:
|
||||
# Serialise as a list for YAML / JSON friendliness
|
||||
config["normalize_keys"] = sorted(self.normalize_keys)
|
||||
if self.normalize_observation_keys is not None:
|
||||
config["normalize_observation_keys"] = sorted(self.normalize_observation_keys)
|
||||
return config
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
|
||||
new_observation = dict(observation)
|
||||
for key, feature in self.features.items():
|
||||
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
|
||||
continue
|
||||
if feature.type != FeatureType.ACTION and key in new_observation:
|
||||
tensor = torch.as_tensor(new_observation[key], dtype=torch.float32)
|
||||
new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
|
||||
return new_observation
|
||||
|
||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
|
||||
tensor = torch.as_tensor(action, dtype=torch.float32)
|
||||
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
|
||||
return processed_action
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
def _apply_transform(
|
||||
self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
|
||||
) -> Tensor:
|
||||
"""Core logic to apply normalization or unnormalization."""
|
||||
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
|
||||
return tensor
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# Ensure input tensor is on the same device as the stats.
|
||||
if self.device and tensor.device != self.device:
|
||||
tensor = tensor.to(self.device)
|
||||
|
||||
# For Accelerate compatibility: move stats to match input tensor device
|
||||
input_device = tensor.device
|
||||
stats = self._tensor_stats[key]
|
||||
tensor = tensor.to(dtype=torch.float32)
|
||||
|
||||
# Move stats to input device if needed
|
||||
stats_device = next(iter(stats.values())).device
|
||||
if stats_device != input_device:
|
||||
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
|
||||
|
||||
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
# Avoid division by zero by adding a small epsilon.
|
||||
denom = std + self.eps
|
||||
if inverse:
|
||||
return tensor * std + mean
|
||||
return (tensor - mean) / denom
|
||||
|
||||
if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
denom = max_val - min_val
|
||||
# When min_val == max_val, substitute the denominator with a small epsilon
|
||||
# to prevent division by zero. This consistently maps an input equal to
|
||||
# min_val to -1, ensuring a stable transformation.
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom
|
||||
)
|
||||
if inverse:
|
||||
# Map from [-1, 1] back to [min, max]
|
||||
return (tensor + 1) / 2 * denom + min_val
|
||||
# Map from [min, max] to [-1, 1]
|
||||
return 2 * (tensor - min_val) / denom - 1
|
||||
|
||||
# If necessary stats are missing, return input unchanged.
|
||||
return tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
class UnnormalizerProcessor:
|
||||
"""Inverse normalisation for observations and actions.
|
||||
|
||||
Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
|
||||
transform.
|
||||
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||
class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies normalization to observations and actions in a transition.
|
||||
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
This class directly implements the normalization logic for both observation and action
|
||||
components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
|
||||
initialization.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
@@ -229,103 +176,97 @@ class UnnormalizerProcessor:
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
) -> UnnormalizerProcessor:
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle deserialization from JSON config
|
||||
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||
reconstructed_features = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed_features[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed_features
|
||||
|
||||
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||
reconstructed_norm_map = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed_norm_map
|
||||
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
|
||||
def _unnormalize_obs(self, observation):
|
||||
if observation is None:
|
||||
return None
|
||||
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
||||
processed = dict(observation)
|
||||
for key in keys:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
continue
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
if isinstance(orig_val, torch.Tensor)
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = tensor * std + mean
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
return processed
|
||||
|
||||
def _unnormalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
return action
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
*,
|
||||
normalize_observation_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
device: torch.device | str | None = None,
|
||||
) -> NormalizerProcessorStep:
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats,
|
||||
normalize_observation_keys=normalize_observation_keys,
|
||||
eps=eps,
|
||||
device=device,
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return tensor * std + mean
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with unnormalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Handle observation normalization.
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(
|
||||
observation, inverse=False
|
||||
)
|
||||
|
||||
# Handle action normalization.
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
||||
},
|
||||
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies unnormalization (the inverse of normalization) to
|
||||
observations and actions in a transition.
|
||||
|
||||
This is typically used to transform actions from a normalized policy output back into
|
||||
the original scale for execution in an environment.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
device: torch.device | str | None = None,
|
||||
) -> UnnormalizerProcessorStep:
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Handle observation unnormalization.
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True)
|
||||
|
||||
# Handle action unnormalization.
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def hotswap_stats(
|
||||
policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]]
|
||||
) -> PolicyProcessorPipeline:
|
||||
"""
|
||||
Replaces normalization statistics in a PolicyProcessor pipeline.
|
||||
|
||||
This function creates a deep copy of the provided `PolicyProcessorPipeline` and updates the
|
||||
statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` steps within it.
|
||||
It's useful for adapting a trained policy to a new environment or dataset with
|
||||
different data distributions.
|
||||
"""
|
||||
rp = deepcopy(policy_processor)
|
||||
for step in rp.steps:
|
||||
if isinstance(step, _NormalizationMixin):
|
||||
step.stats = stats
|
||||
# Re-initialize tensor_stats on the correct device.
|
||||
step._tensor_stats = to_tensor(stats, device=step.device)
|
||||
return rp
|
||||
|
||||
@@ -22,12 +22,13 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class VanillaObservationProcessor(ObservationProcessor):
|
||||
class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes environment observations into the LeRobot format by handling both images and states.
|
||||
|
||||
@@ -106,9 +107,8 @@ class VanillaObservationProcessor(ObservationProcessor):
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms feature keys to a standardized contract.
|
||||
|
||||
This method handles several renaming patterns:
|
||||
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
|
||||
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
|
||||
|
||||
+279
-425
File diff suppressed because it is too large
Load Diff
@@ -13,19 +13,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="rename_processor")
|
||||
class RenameProcessor(ObservationProcessor):
|
||||
class RenameProcessorStep(ObservationProcessorStep):
|
||||
"""Rename processor that renames keys in the observation."""
|
||||
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
@@ -43,9 +42,20 @@ class RenameProcessor(ObservationProcessor):
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"rename_map": self.rename_map}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
||||
- Keys not in `rename_map` remain unchanged.
|
||||
"""
|
||||
return {self.rename_map.get(k, k): v for k, v in features.items()}
|
||||
|
||||
|
||||
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
|
||||
"""Rename keys in the stats dictionary according to rename_map (defensive copy)."""
|
||||
if not stats:
|
||||
return {}
|
||||
renamed: dict[str, dict[str, Any]] = {}
|
||||
for old_key, sub_stats in stats.items():
|
||||
new_key = rename_map.get(old_key, old_key)
|
||||
renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {}
|
||||
return renamed
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Tokenizer processor for handling text tokenization in robot transitions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
else:
|
||||
AutoTokenizer = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="tokenizer_processor")
|
||||
class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
|
||||
|
||||
This processor handles tokenization of task strings found in the complementary_data
|
||||
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
|
||||
to the observation data for model processing while preserving the original task string.
|
||||
|
||||
The processor supports both single strings and lists of strings as task inputs.
|
||||
|
||||
Args:
|
||||
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
|
||||
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
|
||||
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
|
||||
tokenizer: A tokenizer object (e.g., from transformers library) that implements
|
||||
the __call__ method. If provided, tokenizer_name is ignored. This parameter
|
||||
is not serialized and must be provided via overrides when loading.
|
||||
max_length: Maximum sequence length for tokenization. Defaults to 512.
|
||||
task_key: Key in complementary_data containing the task text. Defaults to "task".
|
||||
padding: Padding strategy for tokenization. Defaults to "max_length".
|
||||
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
|
||||
|
||||
Examples:
|
||||
Using tokenizer name (auto-loaded):
|
||||
```python
|
||||
processor = TokenizerProcessorStep(tokenizer_name="bert-base-uncased", max_length=128)
|
||||
```
|
||||
|
||||
Using custom tokenizer object:
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
processor = TokenizerProcessorStep(tokenizer=custom_tokenizer, max_length=128)
|
||||
```
|
||||
"""
|
||||
|
||||
tokenizer_name: str | None = None
|
||||
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
|
||||
max_length: int = 512
|
||||
task_key: str = "task"
|
||||
padding_side: str = "right"
|
||||
padding: str = "max_length"
|
||||
truncation: bool = True
|
||||
|
||||
# Internal tokenizer instance (not serialized)
|
||||
input_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessorStep."
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
# Use provided tokenizer object directly
|
||||
self.input_tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
if AutoTokenizer is None:
|
||||
raise ImportError("AutoTokenizer is not available")
|
||||
self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
|
||||
"Pass a tokenizer object directly or a tokenizer name to auto-load."
|
||||
)
|
||||
|
||||
def get_task(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""Extract and normalize task from complementary data.
|
||||
|
||||
Args:
|
||||
transition: Input transition containing complementary_data.
|
||||
|
||||
Returns:
|
||||
List of task strings if task is present, None otherwise.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
if self.task_key not in complementary_data:
|
||||
return None
|
||||
|
||||
task = complementary_data[self.task_key]
|
||||
if task is None:
|
||||
return None
|
||||
|
||||
# Convert to list of strings
|
||||
if isinstance(task, str):
|
||||
return [task]
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
return task
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation):
|
||||
"""Process the transition by tokenizing the task text.
|
||||
|
||||
Args:
|
||||
transition: Input transition containing complementary_data with task text.
|
||||
|
||||
Returns:
|
||||
Modified transition with tokenized task added to observation.
|
||||
|
||||
Raises:
|
||||
ValueError: If tokenizer initialization failed.
|
||||
"""
|
||||
task = self.get_task(self.transition)
|
||||
if task is None:
|
||||
return observation
|
||||
|
||||
# Tokenize the task (creates CPU tensors)
|
||||
tokenized_prompt = self._tokenize_text(task)
|
||||
|
||||
# Detect device from existing tensors in the transition
|
||||
target_device = self._detect_device(self.transition)
|
||||
|
||||
# Move tokenized tensors to match the device of other data
|
||||
if target_device is not None:
|
||||
tokenized_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_prompt.items()
|
||||
}
|
||||
|
||||
# Get or create observation dict
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Add tokenized data to observation
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
"""Detect device from existing tensors in the transition.
|
||||
|
||||
This allows the tokenized tensors to match the device of other data,
|
||||
which is especially important for multi-GPU training with Accelerate.
|
||||
|
||||
Args:
|
||||
transition: The transition to search for existing tensors.
|
||||
|
||||
Returns:
|
||||
The device of the first tensor found, or None if no tensors exist.
|
||||
"""
|
||||
# Check observation tensors first (most likely to exist)
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation:
|
||||
for value in observation.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
|
||||
# Check action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if isinstance(action, torch.Tensor):
|
||||
return action.device
|
||||
|
||||
return None # No tensors found, keep on CPU
|
||||
|
||||
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
||||
"""Tokenize text using the configured tokenizer.
|
||||
|
||||
Args:
|
||||
text: Text string or list of strings to tokenize.
|
||||
|
||||
Returns:
|
||||
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
|
||||
"""
|
||||
return self.input_tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
truncation=self.truncation,
|
||||
padding=self.padding,
|
||||
padding_side=self.padding_side,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization.
|
||||
|
||||
Note: Only tokenizer_name is saved, not the tokenizer object itself.
|
||||
When loading, provide the tokenizer via overrides if needed.
|
||||
"""
|
||||
config = {
|
||||
"max_length": self.max_length,
|
||||
"task_key": self.task_key,
|
||||
"padding_side": self.padding_side,
|
||||
"padding": self.padding,
|
||||
"truncation": self.truncation,
|
||||
}
|
||||
|
||||
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
|
||||
# TODO(steven): Consider saving the name of the _tokenizer if it was loaded
|
||||
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||
config["tokenizer_name"] = self.tokenizer_name
|
||||
|
||||
return config
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Add tokenized task features to the feature contract.
|
||||
|
||||
Args:
|
||||
features: Input feature dictionary.
|
||||
|
||||
Returns:
|
||||
Updated feature dictionary with tokenized task features added.
|
||||
"""
|
||||
# Add features for tokenized output if they don't exist
|
||||
# Standard tokenizer output includes tokens and attention_mask
|
||||
|
||||
if OBS_LANGUAGE_TOKENS not in features:
|
||||
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
|
||||
if OBS_LANGUAGE_ATTENTION_MASK not in features:
|
||||
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
+162
-30
@@ -59,7 +59,7 @@ lerobot-record \
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
@@ -72,10 +72,23 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RobotProcessorPipeline,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
action_to_transition,
|
||||
observation_to_transition,
|
||||
transition_to_dataset_frame,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -149,6 +162,8 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.single_task is None:
|
||||
@@ -187,6 +202,36 @@ class RecordConfig:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
""" --------------- record_loop() data flow --------------------------
|
||||
[ Robot ]
|
||||
V
|
||||
[ robot.get_observation() ] ---> raw_obs
|
||||
V
|
||||
[ robot_observation_processor ] ---> obs_transition
|
||||
V
|
||||
.-----( ACTION LOGIC )------------------.
|
||||
V V
|
||||
[ From Teleoperator ] [ From Policy ]
|
||||
| |
|
||||
| [teleop.get_action] -> raw_action | [predict_action]
|
||||
| | | |
|
||||
| V | V
|
||||
| [teleop_action_processor] | |
|
||||
| | | |
|
||||
'---> teleop_transition '---> policy_transition
|
||||
| |
|
||||
'-------------------------.-------------'
|
||||
V
|
||||
[ robot_action_processor ] --> robot_action_to_send
|
||||
V
|
||||
[ robot.send_action() ] -- (Robot Executes)
|
||||
V
|
||||
( Transitions are merged & added to Dataset )
|
||||
V
|
||||
( Rerun Log / Loop Wait )
|
||||
"""
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def record_loop(
|
||||
robot: Robot,
|
||||
@@ -195,15 +240,32 @@ def record_loop(
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
control_time_s: int | None = None,
|
||||
teleop_action_processor: RobotProcessorPipeline | None = None, # runs after teleop
|
||||
robot_action_processor: RobotProcessorPipeline | None = None, # runs before robot
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None, # runs after robot
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
):
|
||||
teleop_action_processor = teleop_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=lambda tr: tr
|
||||
)
|
||||
robot_action_processor = robot_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()], to_transition=lambda tr: tr, to_output=transition_to_robot_action
|
||||
)
|
||||
robot_observation_processor = robot_observation_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
|
||||
teleop_arm = teleop_keyboard = None
|
||||
if isinstance(teleop, list):
|
||||
if isinstance(teleop, list): # For LeKiwi
|
||||
teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
|
||||
teleop_arm = next(
|
||||
(
|
||||
@@ -219,9 +281,20 @@ def record_loop(
|
||||
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
|
||||
)
|
||||
|
||||
# if policy is given it needs cleaning up
|
||||
if policy is not None:
|
||||
# Reset policy and processor if they are provided
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Reset custom pipelines
|
||||
teleop_action_processor.reset()
|
||||
robot_action_processor.reset()
|
||||
robot_observation_processor.reset()
|
||||
|
||||
policy_transition = None
|
||||
teleop_transition = None
|
||||
obs_transition = None
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
@@ -232,51 +305,88 @@ def record_loop(
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
observation = robot.get_observation()
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
|
||||
if policy is not None or dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
# Applies a pipeline to the raw robot observation, default is IdentityProcessor
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
if dataset is not None:
|
||||
observation_frame = transition_to_dataset_frame(
|
||||
obs_transition, dataset.features
|
||||
) # Convert the observation to the dataset format
|
||||
|
||||
if policy is not None:
|
||||
action_values = predict_action(
|
||||
observation_frame,
|
||||
policy,
|
||||
get_safe_torch_device(policy.config.device),
|
||||
policy.config.use_amp,
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
action = teleop.get_action()
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
# TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
|
||||
|
||||
action_names = dataset.features["action"]["names"]
|
||||
policy_action = {f"action.{name}": float(action_values[i]) for i, name in enumerate(action_names)}
|
||||
policy_transition = {
|
||||
TransitionKey.ACTION: policy_action,
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
elif isinstance(teleop, Teleoperator):
|
||||
act = teleop.get_action()
|
||||
|
||||
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
||||
teleop_transition = teleop_action_processor(act)
|
||||
|
||||
elif isinstance(teleop, list):
|
||||
arm_action = teleop_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
|
||||
keyboard_action = teleop_keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
teleop_transition = teleop_action_processor(act)
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
"No policy or teleoperator provided, skipping action generation. "
|
||||
"This is likely to happen when resetting the environment without a teleop device."
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
continue
|
||||
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
sent_action = robot.send_action(action)
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
# IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
|
||||
if policy is not None and policy_transition is not None:
|
||||
robot_action_to_send = robot_action_processor(policy_transition)
|
||||
else:
|
||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
# TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_ = robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||
frame = {**observation_frame, **action_frame}
|
||||
# If transition_to_dataset_frame is provided, use it to merge the transitions.
|
||||
merged = []
|
||||
if obs_transition is not None: # The observation from the robot
|
||||
merged.append(obs_transition)
|
||||
if teleop_transition is not None: # The action from teleop
|
||||
merged.append(teleop_transition)
|
||||
if policy_transition is not None: # The action from policy
|
||||
merged.append(policy_transition)
|
||||
frame = transition_to_dataset_frame(
|
||||
merged if len(merged) > 1 else merged[0], dataset.features
|
||||
) # Convert the observation to the dataset format
|
||||
dataset.add_frame(frame, task=single_task)
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(observation, action)
|
||||
log_rerun_data([obs_transition, teleop_transition or policy_transition])
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
@@ -296,7 +406,15 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Add next.* features that are generated during recording
|
||||
transition_features = {
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
"next.truncated": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
dataset_features = {**action_features, **obs_features, **transition_features}
|
||||
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
@@ -328,6 +446,18 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
@@ -345,6 +475,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
fps=cfg.dataset.fps,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
|
||||
+20
-4
@@ -45,9 +45,10 @@ from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -83,13 +84,25 @@ class ReplayConfig:
|
||||
dataset: DatasetReplayConfig
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Optional processor for actions before sending to robot
|
||||
robot_action_processor: RobotProcessorPipeline | None = None
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
@parser.wrap()
|
||||
def replay(cfg: ReplayConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Initialize robot action processor with default if not provided
|
||||
robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=action_to_transition,
|
||||
to_output=transition_to_robot_action, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Reset processor
|
||||
robot_action_processor.reset()
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
@@ -104,7 +117,10 @@ def replay(cfg: ReplayConfig):
|
||||
for i, name in enumerate(dataset.features["action"]["names"]):
|
||||
action[name] = action_array[i]
|
||||
|
||||
robot.send_action(action)
|
||||
# Process action through robot action processor
|
||||
processed_action = robot_action_processor(action)
|
||||
|
||||
robot.send_action(processed_action) # type: ignore[arg-type]
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
|
||||
@@ -14,6 +14,5 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
|
||||
from .config_so100_follower import SO100FollowerConfig
|
||||
from .so100_follower import SO100Follower
|
||||
from .so100_follower_end_effector import SO100FollowerEndEffector
|
||||
|
||||
@@ -39,35 +39,3 @@ class SO100FollowerConfig(RobotConfig):
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so100_follower_end_effector")
|
||||
@dataclass
|
||||
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
|
||||
"""Configuration for the SO100FollowerEndEffector robot."""
|
||||
|
||||
# Path to URDF file for kinematics
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
urdf_path: str | None = None
|
||||
|
||||
# End-effector frame name in the URDF
|
||||
target_frame_name: str = "gripper_frame_link"
|
||||
|
||||
# Default bounds for the end-effector position (in meters)
|
||||
end_effector_bounds: dict[str, list[float]] = field(
|
||||
default_factory=lambda: {
|
||||
"min": [-1.0, -1.0, -1.0], # min x, y, z
|
||||
"max": [1.0, 1.0, 1.0], # max x, y, z
|
||||
}
|
||||
)
|
||||
|
||||
max_gripper_pos: float = 50
|
||||
|
||||
end_effector_step_sizes: dict[str, float] = field(
|
||||
default_factory=lambda: {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -0,0 +1,460 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
ActionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
EnvTransition,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_reference_and_delta")
|
||||
@dataclass
|
||||
class EEReferenceAndDelta(ActionProcessorStep):
|
||||
"""
|
||||
Compute the desired end-effector pose from the target pose and the current pose.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
end_effector_step_sizes: dict
|
||||
motor_names: list[str]
|
||||
use_latched_reference: bool = (
|
||||
True # If True, latch reference on enable; if False, always use current pose
|
||||
)
|
||||
|
||||
reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
||||
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, action):
|
||||
new_action = action.copy()
|
||||
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
if raw is None:
|
||||
raise ValueError(
|
||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||
)
|
||||
|
||||
if "reference_joint_positions" in comp:
|
||||
q = comp["reference_joint_positions"]
|
||||
else:
|
||||
q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
|
||||
# Current pose from FK on measured joints
|
||||
t_curr = self.kinematics.forward_kinematics(q)
|
||||
|
||||
enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
|
||||
tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
|
||||
ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
|
||||
tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
|
||||
wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
|
||||
wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
|
||||
wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))
|
||||
|
||||
desired = None
|
||||
|
||||
if enabled:
|
||||
ref = t_curr
|
||||
if self.use_latched_reference:
|
||||
# Latched reference mode: latch reference at the rising edge
|
||||
if not self._prev_enabled or self.reference_ee_pose is None:
|
||||
self.reference_ee_pose = t_curr.copy()
|
||||
ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr
|
||||
|
||||
delta_p = np.array(
|
||||
[
|
||||
tx * self.end_effector_step_sizes["x"],
|
||||
ty * self.end_effector_step_sizes["y"],
|
||||
tz * self.end_effector_step_sizes["z"],
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
|
||||
desired = np.eye(4, dtype=float)
|
||||
desired[:3, :3] = ref[:3, :3] @ r_abs
|
||||
desired[:3, 3] = ref[:3, 3] + delta_p
|
||||
|
||||
self._command_when_disabled = desired.copy()
|
||||
else:
|
||||
# While disabled, keep sending the same command to avoid drift.
|
||||
if self._command_when_disabled is None:
|
||||
# If we've never had an enabled command yet, freeze current FK pose once.
|
||||
self._command_when_disabled = t_curr.copy()
|
||||
desired = self._command_when_disabled.copy()
|
||||
|
||||
# Write action fields
|
||||
pos = desired[:3, 3]
|
||||
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
||||
new_action[f"{ACTION}.ee.x"] = float(pos[0])
|
||||
new_action[f"{ACTION}.ee.y"] = float(pos[1])
|
||||
new_action[f"{ACTION}.ee.z"] = float(pos[2])
|
||||
new_action[f"{ACTION}.ee.wx"] = float(tw[0])
|
||||
new_action[f"{ACTION}.ee.wy"] = float(tw[1])
|
||||
new_action[f"{ACTION}.ee.wz"] = float(tw[2])
|
||||
|
||||
self._prev_enabled = enabled
|
||||
return new_action
|
||||
|
||||
def reset(self):
|
||||
self._prev_enabled = False
|
||||
self.reference_ee_pose = None
|
||||
self._command_when_disabled = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.enabled", None)
|
||||
features.pop(f"{ACTION}.target_x", None)
|
||||
features.pop(f"{ACTION}.target_y", None)
|
||||
features.pop(f"{ACTION}.target_z", None)
|
||||
features.pop(f"{ACTION}.target_wx", None)
|
||||
features.pop(f"{ACTION}.target_wy", None)
|
||||
features.pop(f"{ACTION}.target_wz", None)
|
||||
|
||||
features[f"{ACTION}.ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_bounds_and_safety")
|
||||
@dataclass
|
||||
class EEBoundsAndSafety(ActionProcessorStep):
|
||||
"""
|
||||
Clip the end-effector pose to the bounds and check for jumps.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
"""
|
||||
|
||||
end_effector_bounds: dict
|
||||
max_ee_step_m: float = 0.05
|
||||
max_ee_twist_step_rad: float = 0.20
|
||||
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
raise ValueError(
|
||||
"Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action"
|
||||
)
|
||||
|
||||
pos = np.array([x, y, z], dtype=float)
|
||||
twist = np.array([wx, wy, wz], dtype=float)
|
||||
|
||||
# Clip position
|
||||
pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"])
|
||||
|
||||
# Check for jumps in position
|
||||
if self._last_pos is not None:
|
||||
dpos = pos - self._last_pos
|
||||
n = float(np.linalg.norm(dpos))
|
||||
if n > self.max_ee_step_m and n > 0:
|
||||
pos = self._last_pos + dpos * (self.max_ee_step_m / n)
|
||||
raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")
|
||||
|
||||
self._last_pos = pos
|
||||
self._last_twist = twist
|
||||
|
||||
act[f"{ACTION}.ee.x"] = float(pos[0])
|
||||
act[f"{ACTION}.ee.y"] = float(pos[1])
|
||||
act[f"{ACTION}.ee.z"] = float(pos[2])
|
||||
act[f"{ACTION}.ee.wx"] = float(twist[0])
|
||||
act[f"{ACTION}.ee.wy"] = float(twist[1])
|
||||
act[f"{ACTION}.ee.wz"] = float(twist[2])
|
||||
return act
|
||||
|
||||
def reset(self):
|
||||
self._last_pos = None
|
||||
self._last_twist = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# check if features as f"{ACTION}.ee.{x,y,z,wx,wy,wz}"
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
|
||||
@dataclass
|
||||
class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
"""
|
||||
Compute the desired joint positions from the desired end-effector pose.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.joint_name_1.pos": float,
|
||||
"action.joint_name_2.pos": float,
|
||||
...
|
||||
"action.joint_name_n.pos": float,
|
||||
}
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
initial_guess_current_joints: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
return new_transition
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
if raw is None:
|
||||
raise ValueError(
|
||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||
)
|
||||
|
||||
if self.initial_guess_current_joints: # Use current joints as initial guess
|
||||
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
else: # Use previous ik solution as initial guess
|
||||
if self.q_curr is None:
|
||||
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
|
||||
# Build desired 4x4 transform from pos + rotvec (twist)
|
||||
t_des = np.eye(4, dtype=float)
|
||||
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
|
||||
t_des[:3, 3] = [x, y, z]
|
||||
|
||||
# Compute inverse kinematics
|
||||
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
|
||||
self.q_curr = q_target
|
||||
|
||||
new_act = dict(act)
|
||||
for i, name in enumerate(self.motor_names):
|
||||
if name == "gripper":
|
||||
# TODO(pepijn): Investigate if this is correct
|
||||
# Do we want an observation key in the action field?
|
||||
new_act[f"{ACTION}.gripper.pos"] = float(raw["gripper"])
|
||||
else:
|
||||
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
if not self.initial_guess_current_joints:
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for name in self.motor_names:
|
||||
features[f"{ACTION}.{name}.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
self.q_curr = None
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
|
||||
@dataclass
|
||||
class GripperVelocityToJoint(ProcessorStep):
|
||||
"""
|
||||
Convert the gripper velocity to a joint velocity.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.gripper": float,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.gripper.pos": float,
|
||||
}
|
||||
"""
|
||||
|
||||
motor_names: list[str]
|
||||
speed_factor: float = 20.0
|
||||
clip_min: float = 0.0
|
||||
clip_max: float = 100.0
|
||||
discrete_gripper: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if f"{ACTION}.gripper" not in act:
|
||||
raise ValueError(f"Required action key '{ACTION}.gripper' not found in transition")
|
||||
|
||||
if "gripper" not in self.motor_names:
|
||||
raise ValueError(
|
||||
f"Required motor name 'gripper' not found in self.motor_names={self.motor_names}"
|
||||
)
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
# 0: open, 1: close, 2: stay
|
||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
gripper_action = act.get(f"{ACTION}.gripper", 1.0)
|
||||
gripper_action = gripper_action - 1.0
|
||||
gripper_action *= self.clip_max
|
||||
act[f"{ACTION}.gripper"] = gripper_action
|
||||
|
||||
# Get current gripper position from complementary data
|
||||
raw = comp.get("raw_joint_positions") or {}
|
||||
curr_pos = float(raw.get("gripper"))
|
||||
|
||||
# Compute desired gripper velocity
|
||||
u = float(act.get(f"{ACTION}.gripper", 0.0))
|
||||
delta = u * float(self.speed_factor)
|
||||
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
|
||||
|
||||
new_act = dict(act)
|
||||
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
|
||||
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{OBS_STATE}.gripper.pos"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
|
||||
@dataclass
|
||||
class ForwardKinematicsJointsToEE(ObservationProcessorStep):
|
||||
"""
|
||||
Compute the end-effector pose from the joint positions.
|
||||
|
||||
Input OBSERVATION keys:
|
||||
{
|
||||
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
|
||||
}
|
||||
|
||||
Output OBSERVATION keys:
|
||||
{
|
||||
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
|
||||
def observation(self, obs: dict) -> dict:
|
||||
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
|
||||
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
|
||||
|
||||
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
|
||||
t = self.kinematics.forward_kinematics(q)
|
||||
pos = t[:3, 3]
|
||||
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||
|
||||
obs[f"{OBS_STATE}.ee.x"] = float(pos[0])
|
||||
obs[f"{OBS_STATE}.ee.y"] = float(pos[1])
|
||||
obs[f"{OBS_STATE}.ee.z"] = float(pos[2])
|
||||
obs[f"{OBS_STATE}.ee.wx"] = float(tw[0])
|
||||
obs[f"{OBS_STATE}.ee.wy"] = float(tw[1])
|
||||
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
|
||||
return obs
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz"]:
|
||||
features[f"{OBS_STATE}.ee.{k}"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_robot_observation")
|
||||
@dataclass
|
||||
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Read the robot's current observation and insert it into the transition as complementary data.
|
||||
|
||||
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
|
||||
{ "<motor_name>": <float position>, ... }
|
||||
"""
|
||||
|
||||
robot: Robot
|
||||
|
||||
def complementary_data(self, comp: dict | None) -> dict:
|
||||
new_comp = dict(comp)
|
||||
obs = (
|
||||
self.robot.get_observation()
|
||||
) # todo(steven): why not self.trtansition.get(TransitionKey.OBSERVATION)?
|
||||
|
||||
new_comp["raw_joint_positions"] = {
|
||||
k.removesuffix(".pos"): float(v)
|
||||
for k, v in obs.items()
|
||||
if isinstance(k, str) and k.endswith(".pos")
|
||||
}
|
||||
return new_comp
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -1,200 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras import make_cameras_from_configs
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.motors import Motor, MotorNormMode
|
||||
from lerobot.motors.feetech import FeetechMotorsBus
|
||||
|
||||
from . import SO100Follower
|
||||
from .config_so100_follower import SO100FollowerEndEffectorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SO100FollowerEndEffector(SO100Follower):
|
||||
"""
|
||||
SO100Follower robot with end-effector space control.
|
||||
|
||||
This robot inherits from SO100Follower but transforms actions from
|
||||
end-effector space to joint space before sending them to the motors.
|
||||
"""
|
||||
|
||||
config_class = SO100FollowerEndEffectorConfig
|
||||
name = "so100_follower_end_effector"
|
||||
|
||||
def __init__(self, config: SO100FollowerEndEffectorConfig):
|
||||
super().__init__(config)
|
||||
self.bus = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES),
|
||||
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES),
|
||||
"elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES),
|
||||
"wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES),
|
||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES),
|
||||
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
)
|
||||
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.config = config
|
||||
|
||||
# Initialize the kinematics module for the so100 robot
|
||||
if self.config.urdf_path is None:
|
||||
raise ValueError(
|
||||
"urdf_path must be provided in the configuration for end-effector control. "
|
||||
"Please set urdf_path in your SO100FollowerEndEffectorConfig."
|
||||
)
|
||||
|
||||
self.kinematics = RobotKinematics(
|
||||
urdf_path=self.config.urdf_path,
|
||||
target_frame_name=self.config.target_frame_name,
|
||||
)
|
||||
|
||||
# Store the bounds for end-effector position
|
||||
self.end_effector_bounds = self.config.end_effector_bounds
|
||||
|
||||
self.current_ee_pos = None
|
||||
self.current_joint_pos = None
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, Any]:
|
||||
"""
|
||||
Define action features for end-effector control.
|
||||
Returns dictionary with dtype, shape, and names.
|
||||
"""
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform action from end-effector space to joint space and send to motors.
|
||||
|
||||
Args:
|
||||
action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control
|
||||
or a numpy array with [delta_x, delta_y, delta_z]
|
||||
|
||||
Returns:
|
||||
The joint-space action that was sent to the motors
|
||||
"""
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Convert action to numpy array if not already
|
||||
if isinstance(action, dict):
|
||||
if all(k in action for k in ["delta_x", "delta_y", "delta_z"]):
|
||||
delta_ee = np.array(
|
||||
[
|
||||
action["delta_x"] * self.config.end_effector_step_sizes["x"],
|
||||
action["delta_y"] * self.config.end_effector_step_sizes["y"],
|
||||
action["delta_z"] * self.config.end_effector_step_sizes["z"],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
if "gripper" not in action:
|
||||
action["gripper"] = [1.0]
|
||||
action = np.append(delta_ee, action["gripper"])
|
||||
else:
|
||||
logger.warning(
|
||||
f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
|
||||
)
|
||||
action = np.zeros(4, dtype=np.float32)
|
||||
|
||||
if self.current_joint_pos is None:
|
||||
# Read current joint positions
|
||||
current_joint_pos = self.bus.sync_read("Present_Position")
|
||||
self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors])
|
||||
|
||||
# Calculate current end-effector position using forward kinematics
|
||||
if self.current_ee_pos is None:
|
||||
self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
|
||||
|
||||
# Set desired end-effector position by adding delta
|
||||
desired_ee_pos = np.eye(4)
|
||||
desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
|
||||
|
||||
# Add delta to position and clip to bounds
|
||||
desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
|
||||
if self.end_effector_bounds is not None:
|
||||
desired_ee_pos[:3, 3] = np.clip(
|
||||
desired_ee_pos[:3, 3],
|
||||
self.end_effector_bounds["min"],
|
||||
self.end_effector_bounds["max"],
|
||||
)
|
||||
|
||||
# Compute inverse kinematics to get joint positions
|
||||
target_joint_values_in_degrees = self.kinematics.inverse_kinematics(
|
||||
self.current_joint_pos, desired_ee_pos
|
||||
)
|
||||
|
||||
# Create joint space action dictionary
|
||||
joint_action = {
|
||||
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
|
||||
}
|
||||
|
||||
# Handle gripper separately if included in action
|
||||
# Gripper delta action is in the range 0 - 2,
|
||||
# We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos
|
||||
joint_action["gripper.pos"] = np.clip(
|
||||
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
|
||||
5,
|
||||
self.config.max_gripper_pos,
|
||||
)
|
||||
|
||||
self.current_ee_pos = desired_ee_pos.copy()
|
||||
self.current_joint_pos = target_joint_values_in_degrees.copy()
|
||||
self.current_joint_pos[-1] = joint_action["gripper.pos"]
|
||||
|
||||
# Send joint space action to parent class
|
||||
return super().send_action(joint_action)
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def reset(self):
|
||||
self.current_ee_pos = None
|
||||
self.current_joint_pos = None
|
||||
@@ -29,10 +29,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .so100_follower import SO100Follower
|
||||
|
||||
return SO100Follower(config)
|
||||
elif config.type == "so100_follower_end_effector":
|
||||
from .so100_follower import SO100FollowerEndEffector
|
||||
|
||||
return SO100FollowerEndEffector(config)
|
||||
elif config.type == "so101_follower":
|
||||
from .so101_follower import SO101Follower
|
||||
|
||||
@@ -69,6 +65,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
raise ValueError(config.type)
|
||||
|
||||
|
||||
# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset
|
||||
def ensure_safe_goal_position(
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
|
||||
) -> dict[str, float]:
|
||||
|
||||
@@ -62,9 +62,16 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.scripts.rl.gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
@@ -91,7 +98,6 @@ from lerobot.utils.utils import (
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
@@ -236,7 +242,8 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
||||
env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
@@ -257,6 +264,12 @@ def act_with_policy(
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
@@ -274,45 +287,71 @@ def act_with_policy(
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
observation = {
|
||||
k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features
|
||||
}
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
# Use the new step function
|
||||
new_transition = step_env_and_process_transition(
|
||||
env=online_env,
|
||||
transition=transition,
|
||||
action=action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
|
||||
# Extract values from processed transition
|
||||
next_observation = {
|
||||
k: v
|
||||
for k, v in new_transition[TransitionKey.OBSERVATION].items()
|
||||
if k in cfg.policy.input_features
|
||||
}
|
||||
|
||||
# Teleop action is the action that was executed in the environment
|
||||
# It is either the action from the teleop device or the action from the policy
|
||||
executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
|
||||
|
||||
reward = new_transition[TransitionKey.REWARD]
|
||||
done = new_transition.get(TransitionKey.DONE, False)
|
||||
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
complementary_info = {
|
||||
"discrete_penalty": torch.tensor(
|
||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||
),
|
||||
}
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
state=observation,
|
||||
action=executed_action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
next_state=next_observation,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
# Update transition for next iteration
|
||||
transition = new_transition
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
@@ -347,12 +386,20 @@ def act_with_policy(
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
# Reset intervention counters and environment
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -75,6 +75,7 @@ from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl import learner_service
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
MAX_MESSAGE_SIZE,
|
||||
@@ -1048,10 +1049,8 @@ def get_observation_features(
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(
|
||||
next_observations, normalize=True
|
||||
)
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
@@ -1176,7 +1175,7 @@ def process_transitions(
|
||||
|
||||
# Add to offline buffer if it's an intervention
|
||||
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
"is_intervention"
|
||||
TeleopEvents.IS_INTERVENTION
|
||||
):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
|
||||
@@ -26,12 +26,13 @@ from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
@@ -140,6 +141,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
||||
)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
@@ -149,6 +153,12 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
preprocessor.from_pretrained(
|
||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
postprocessor.from_pretrained(
|
||||
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
@@ -203,12 +213,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
batch = preprocessor(batch)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
@@ -240,7 +247,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
save_checkpoint(
|
||||
checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
@@ -284,6 +293,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
policy.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
+77
-17
@@ -56,11 +56,17 @@ import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
|
||||
import draccus
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
action_to_transition,
|
||||
observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -97,39 +103,84 @@ class TeleoperateConfig:
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Optional processors for data transformation
|
||||
teleop_action_processor: RobotProcessorPipeline | None = None # runs after teleop
|
||||
robot_action_processor: RobotProcessorPipeline | None = None # runs before robot
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None # runs after robot
|
||||
|
||||
|
||||
def teleop_loop(
|
||||
teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None
|
||||
teleop: Teleoperator,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
display_data: bool = False,
|
||||
duration: float | None = None,
|
||||
teleop_action_processor: RobotProcessorPipeline | None = None,
|
||||
robot_action_processor: RobotProcessorPipeline | None = None,
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None,
|
||||
):
|
||||
# Initialize processors with defaults if not provided
|
||||
teleop_action_processor = teleop_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=lambda tr: tr
|
||||
)
|
||||
robot_action_processor = robot_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=transition_to_robot_action, # type: ignore[arg-type]
|
||||
)
|
||||
robot_observation_processor = robot_observation_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Reset processors
|
||||
teleop_action_processor.reset()
|
||||
robot_action_processor.reset()
|
||||
robot_observation_processor.reset()
|
||||
|
||||
display_len = max(len(key) for key in robot.action_features)
|
||||
start = time.perf_counter()
|
||||
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
if display_data:
|
||||
observation = robot.get_observation()
|
||||
log_rerun_data(observation, action)
|
||||
|
||||
robot.send_action(action)
|
||||
# Get teleop action
|
||||
raw_action = teleop.get_action()
|
||||
|
||||
# Process teleop action through pipeline
|
||||
teleop_transition = teleop_action_processor(raw_action)
|
||||
|
||||
# Process action for robot through pipeline
|
||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
||||
|
||||
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
|
||||
robot.send_action(robot_action_to_send) # type: ignore[arg-type]
|
||||
|
||||
if display_data:
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
# Process robot observation through pipeline
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
log_rerun_data([obs_transition, teleop_transition])
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
|
||||
# Display the final robot action that was sent
|
||||
for motor, value in robot_action_to_send.items():
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
move_cursor_up(len(robot_action_to_send) + 5)
|
||||
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
|
||||
for motor, value in action.items():
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
|
||||
if duration is not None and time.perf_counter() - start >= duration:
|
||||
return
|
||||
|
||||
move_cursor_up(len(action) + 5)
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
@parser.wrap()
|
||||
def teleoperate(cfg: TeleoperateConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
@@ -143,7 +194,16 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s)
|
||||
teleop_loop(
|
||||
teleop=teleop,
|
||||
robot=robot,
|
||||
fps=cfg.fps,
|
||||
display_data=cfg.display_data,
|
||||
duration=cfg.teleop_time_s,
|
||||
teleop_action_processor=cfg.teleop_action_processor,
|
||||
robot_action_processor=cfg.robot_action_processor,
|
||||
robot_observation_processor=cfg.robot_observation_processor,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
|
||||
@@ -16,4 +16,4 @@
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
from .utils import make_teleoperator_from_config
|
||||
from .utils import TeleopEvents, make_teleoperator_from_config
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
import logging
|
||||
|
||||
from ..utils import TeleopEvents
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
@@ -134,10 +136,10 @@ class KeyboardController(InputController):
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -255,13 +257,13 @@ class GamepadController(InputController):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
@@ -451,11 +453,11 @@ class GamepadControllerHID(InputController):
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
from .configuration_gamepad import GamepadTeleopConfig
|
||||
|
||||
|
||||
@@ -93,9 +94,9 @@ class GamepadTeleop(Teleoperator):
|
||||
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
|
||||
|
||||
action_dict = {
|
||||
"delta_x": gamepad_action[0],
|
||||
"delta_y": gamepad_action[1],
|
||||
"delta_z": gamepad_action[2],
|
||||
"action.delta_x": gamepad_action[0],
|
||||
"action.delta_y": gamepad_action[1],
|
||||
"action.delta_z": gamepad_action[2],
|
||||
}
|
||||
|
||||
# Default gripper action is to stay
|
||||
@@ -107,6 +108,48 @@ class GamepadTeleop(Teleoperator):
|
||||
|
||||
return action_dict
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get extra control events from the gamepad such as intervention status,
|
||||
episode termination, success indicators, etc.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
"""
|
||||
if self.gamepad is None:
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: False,
|
||||
TeleopEvents.TERMINATE_EPISODE: False,
|
||||
TeleopEvents.SUCCESS: False,
|
||||
TeleopEvents.RERECORD_EPISODE: False,
|
||||
}
|
||||
|
||||
# Update gamepad state to get fresh inputs
|
||||
self.gamepad.update()
|
||||
|
||||
# Check if intervention is active
|
||||
is_intervention = self.gamepad.should_intervene()
|
||||
|
||||
# Get episode end status
|
||||
episode_end_status = self.gamepad.get_episode_end_status()
|
||||
terminate_episode = episode_end_status in [
|
||||
TeleopEvents.RERECORD_EPISODE,
|
||||
TeleopEvents.FAILURE,
|
||||
]
|
||||
success = episode_end_status == TeleopEvents.SUCCESS
|
||||
rerecord_episode = episode_end_status == TeleopEvents.RERECORD_EPISODE
|
||||
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: is_intervention,
|
||||
TeleopEvents.TERMINATE_EPISODE: terminate_episode,
|
||||
TeleopEvents.SUCCESS: success,
|
||||
TeleopEvents.RERECORD_EPISODE: rerecord_episode,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the gamepad."""
|
||||
if self.gamepad is not None:
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig
|
||||
|
||||
PYNPUT_AVAILABLE = True
|
||||
@@ -167,25 +168,15 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2, "action.gripper": 3},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (3,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
|
||||
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2},
|
||||
}
|
||||
|
||||
def _on_press(self, key):
|
||||
if hasattr(key, "char"):
|
||||
key = key.char
|
||||
self.event_queue.put((key, True))
|
||||
|
||||
def _on_release(self, key):
|
||||
if hasattr(key, "char"):
|
||||
key = key.char
|
||||
self.event_queue.put((key, False))
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
@@ -226,12 +217,75 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
self.current_pressed.clear()
|
||||
|
||||
action_dict = {
|
||||
"delta_x": delta_x,
|
||||
"delta_y": delta_y,
|
||||
"delta_z": delta_z,
|
||||
"action.delta_x": delta_x,
|
||||
"action.delta_y": delta_y,
|
||||
"action.delta_z": delta_z,
|
||||
}
|
||||
|
||||
if self.config.use_gripper:
|
||||
action_dict["gripper"] = gripper_action
|
||||
|
||||
return action_dict
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get extra control events from the keyboard such as intervention status,
|
||||
episode termination, success indicators, etc.
|
||||
|
||||
Keyboard mappings:
|
||||
- Any movement keys pressed = intervention active
|
||||
- 's' key = success (terminate episode successfully)
|
||||
- 'r' key = rerecord episode (terminate and rerecord)
|
||||
- 'q' key = quit episode (terminate without success)
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: False,
|
||||
TeleopEvents.TERMINATE_EPISODE: False,
|
||||
TeleopEvents.SUCCESS: False,
|
||||
TeleopEvents.RERECORD_EPISODE: False,
|
||||
}
|
||||
|
||||
# Check if any movement keys are currently pressed (indicates intervention)
|
||||
movement_keys = [
|
||||
keyboard.Key.up,
|
||||
keyboard.Key.down,
|
||||
keyboard.Key.left,
|
||||
keyboard.Key.right,
|
||||
keyboard.Key.shift,
|
||||
keyboard.Key.shift_r,
|
||||
keyboard.Key.ctrl_r,
|
||||
keyboard.Key.ctrl_l,
|
||||
]
|
||||
is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys)
|
||||
|
||||
# Check for episode control commands from misc_keys_queue
|
||||
terminate_episode = False
|
||||
success = False
|
||||
rerecord_episode = False
|
||||
|
||||
# Process any pending misc keys
|
||||
while not self.misc_keys_queue.empty():
|
||||
key = self.misc_keys_queue.get_nowait()
|
||||
if key == "s":
|
||||
success = True
|
||||
elif key == "r":
|
||||
terminate_episode = True
|
||||
rerecord_episode = True
|
||||
elif key == "q":
|
||||
terminate_episode = True
|
||||
success = False
|
||||
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: is_intervention,
|
||||
TeleopEvents.TERMINATE_EPISODE: terminate_episode,
|
||||
TeleopEvents.SUCCESS: success,
|
||||
TeleopEvents.RERECORD_EPISODE: rerecord_episode,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_phone import PhoneConfig
|
||||
from .phone import Phone
|
||||
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
class PhoneOS(Enum):
|
||||
ANDROID = "android"
|
||||
IOS = "ios"
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("phone")
|
||||
@dataclass
|
||||
class PhoneConfig(TeleoperatorConfig):
|
||||
phone_os: PhoneOS = PhoneOS.IOS
|
||||
camera_offset = np.array(
|
||||
[0.0, -0.02, 0.04]
|
||||
) # iPhone 14 Pro camera is 2cm off center and 4cm above center
|
||||
@@ -0,0 +1,97 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.processor import ActionProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneOS
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_phone_action_to_robot_action")
|
||||
@dataclass
|
||||
class MapPhoneActionToRobotAction(ActionProcessorStep):
|
||||
"""
|
||||
Map calibrated phone pose (actions) to the inputs for robot actions
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.phone.enabled": bool,
|
||||
"action.phone.pos": np.ndarray,
|
||||
"action.phone.rot": Rotation,
|
||||
"action.phone.raw_inputs": dict,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"action.gripper": float,
|
||||
}
|
||||
"""
|
||||
|
||||
platform: PhoneOS
|
||||
_enabled_prev: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
# Pop them from the action
|
||||
enabled = bool(act.pop(f"{ACTION}.phone.enabled", 0))
|
||||
pos = act.pop(f"{ACTION}.phone.pos", None)
|
||||
rot = act.pop(f"{ACTION}.phone.rot", None)
|
||||
inputs = act.pop(f"{ACTION}.phone.raw_inputs", {})
|
||||
|
||||
if pos is None or rot is None:
|
||||
raise ValueError("pos and rot must be present in action")
|
||||
|
||||
rotvec = rot.as_rotvec() # Absolute orientation as rotvec
|
||||
|
||||
# Map certain inputs to certain actions
|
||||
if self.platform == PhoneOS.IOS:
|
||||
gripper = float(inputs.get("a3", 0.0))
|
||||
else:
|
||||
a = float(inputs.get("reservedButtonA", 0.0))
|
||||
b = float(inputs.get("reservedButtonB", 0.0))
|
||||
gripper = (
|
||||
a - b
|
||||
) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
|
||||
|
||||
# For some actions we need to invert the axis
|
||||
act[f"{ACTION}.enabled"] = enabled
|
||||
act[f"{ACTION}.target_x"] = -pos[1] if enabled else 0.0
|
||||
act[f"{ACTION}.target_y"] = pos[0] if enabled else 0.0
|
||||
act[f"{ACTION}.target_z"] = pos[2] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wx"] = rotvec[1] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wy"] = rotvec[0] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wz"] = -rotvec[2] if enabled else 0.0
|
||||
act[f"{ACTION}.gripper"] = gripper # Still send gripper action when disabled
|
||||
return act
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.phone.enabled", None)
|
||||
features.pop(f"{ACTION}.phone.pos", None)
|
||||
features.pop(f"{ACTION}.phone.rot", None)
|
||||
features.pop(f"{ACTION}.phone.raw_inputs", None)
|
||||
|
||||
features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
@@ -0,0 +1,358 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Docs:
|
||||
# hebi: https://docs.hebi.us/tools.html#mobile-io
|
||||
# teleop: https://github.com/SpesRobotics/teleop
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import hebi
|
||||
import numpy as np
|
||||
from teleop import Teleop
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasePhone:
|
||||
_enabled: bool = False
|
||||
_calib_pos: np.ndarray | None = None
|
||||
_calib_rot_inv: Rotation | None = None
|
||||
|
||||
def _reapply_position_calibration(self, pos: np.ndarray) -> None:
|
||||
self._calib_pos = pos.copy()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {
|
||||
"phone.pos": np.ndarray, # shape (3,)
|
||||
"phone.rot": Rotation, # scipy.spatial.transform.Rotation
|
||||
"phone.raw_inputs": dict, # analogs/buttons or webXR meta
|
||||
"phone.enabled": bool,
|
||||
}
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
# No haptic or other feedback implemented yet
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
# No additional configuration required for phone teleop
|
||||
pass
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# We could add haptic feedback (vibrations) here, but it's not implemented yet
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IOSPhone(BasePhone, Teleoperator):
|
||||
name = "ios_phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._group = None
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._group is not None
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
|
||||
lookup = hebi.Lookup()
|
||||
time.sleep(2.0)
|
||||
group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
|
||||
if group is None:
|
||||
raise RuntimeError("Mobile I/O not found — check name/family settings in the app.")
|
||||
self._group = group
|
||||
logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
|
||||
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
print(
|
||||
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
|
||||
)
|
||||
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
|
||||
position, rotation = self._wait_for_capture_trigger()
|
||||
self._calib_pos = position.copy()
|
||||
self._calib_rot_inv = rotation.inv()
|
||||
self._enabled = False
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
while True:
|
||||
has_pose, position, rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
io = getattr(fb_pose, "io", None)
|
||||
button_b = getattr(io, "b", None) if io is not None else None
|
||||
button_b1_pressed = False
|
||||
if button_b is not None:
|
||||
button_b1_pressed = bool(button_b.get_int(1))
|
||||
if button_b1_pressed:
|
||||
return position, rotation
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
fbk = self._group.get_next_feedback()
|
||||
pose = fbk[0]
|
||||
ar_pos = getattr(pose, "ar_position", None)
|
||||
ar_quat = getattr(pose, "ar_orientation", None)
|
||||
if ar_pos is None or ar_quat is None:
|
||||
return False, None, None, None
|
||||
# HEBI provides orientation in w, x, y, z format.
|
||||
# Scipy's Rotation expects x, y, z, w.
|
||||
quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw
|
||||
rot = Rotation.from_quat(quat_xyzw)
|
||||
pos = ar_pos - rot.apply(self.config.camera_offset)
|
||||
return True, pos, rot, pose
|
||||
|
||||
def get_action(self) -> dict:
|
||||
has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose or not self.is_calibrated:
|
||||
return {}
|
||||
|
||||
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
|
||||
raw_inputs: dict[str, float | int | bool] = {}
|
||||
io = getattr(fb_pose, "io", None)
|
||||
if io is not None:
|
||||
bank_a, bank_b = io.a, io.b
|
||||
if bank_a:
|
||||
for ch in range(1, 9):
|
||||
if bank_a.has_float(ch):
|
||||
raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch))
|
||||
if bank_b:
|
||||
for ch in range(1, 9):
|
||||
if bank_b.has_int(ch):
|
||||
raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch))
|
||||
elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
|
||||
raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch))
|
||||
|
||||
enable = bool(raw_inputs.get("b1", 0))
|
||||
|
||||
# Rising edge then re-capture calibration immediately from current raw pose
|
||||
if enable and not self._enabled:
|
||||
self._reapply_position_calibration(raw_position)
|
||||
|
||||
# Apply calibration
|
||||
pos_cal = self._calib_rot_inv.apply(raw_position - self._calib_pos)
|
||||
rot_cal = self._calib_rot_inv * raw_rotation
|
||||
|
||||
self._enabled = enable
|
||||
|
||||
return {
|
||||
"phone.pos": pos_cal,
|
||||
"phone.rot": rot_cal,
|
||||
"phone.raw_inputs": raw_inputs,
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._group = None
|
||||
|
||||
|
||||
class AndroidPhone(BasePhone, Teleoperator):
|
||||
name = "android_phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._teleop = None
|
||||
self._teleop_thread = None
|
||||
self._latest_pose = None
|
||||
self._latest_message = None
|
||||
self._android_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._teleop is not None
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Starting teleop stream for Android...")
|
||||
self._teleop = Teleop()
|
||||
self._teleop.subscribe(self._android_callback)
|
||||
self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
|
||||
self._teleop_thread.start()
|
||||
logger.info(f"{self} connected, teleop stream started.")
|
||||
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
print(
|
||||
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
|
||||
)
|
||||
print("Touch and move on the WebXR page to capture this pose...\n")
|
||||
|
||||
pos, rot = self._wait_for_capture_trigger()
|
||||
self._calib_pos = pos.copy()
|
||||
self._calib_rot_inv = rot.inv()
|
||||
self._enabled = False
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
while True:
|
||||
with self._android_lock:
|
||||
msg = self._latest_message or {}
|
||||
|
||||
if bool(msg.get("move", False)):
|
||||
ok, pos, rot, _pose = self._read_current_pose()
|
||||
if ok:
|
||||
return pos, rot
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
with self._android_lock:
|
||||
if self._latest_pose is None:
|
||||
return False, None, None, None
|
||||
p = self._latest_pose.copy()
|
||||
pose = self._latest_pose
|
||||
rot = Rotation.from_matrix(p[:3, :3])
|
||||
pos = p[:3, 3] - rot.apply(self.config.camera_offset)
|
||||
return True, pos, rot, pose
|
||||
|
||||
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
|
||||
with self._android_lock:
|
||||
self._latest_pose = pose
|
||||
self._latest_message = message
|
||||
|
||||
def get_action(self) -> dict:
|
||||
ok, raw_pos, raw_rot, pose = self._read_current_pose()
|
||||
if not ok or not self.is_calibrated:
|
||||
return {}
|
||||
|
||||
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
|
||||
raw_inputs: dict[str, float | int | bool] = {}
|
||||
msg = self._latest_message or {}
|
||||
raw_inputs["move"] = bool(msg.get("move", False))
|
||||
raw_inputs["scale"] = float(msg.get("scale", 1.0))
|
||||
raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False))
|
||||
raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False))
|
||||
|
||||
enable = bool(raw_inputs.get("move", False))
|
||||
|
||||
# Rising edge then re-capture calibration immediately from current raw pose
|
||||
if enable and not self._enabled:
|
||||
self._reapply_position_calibration(raw_pos)
|
||||
|
||||
# Apply calibration
|
||||
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
|
||||
rot_cal = self._calib_rot_inv * raw_rot
|
||||
|
||||
self._enabled = enable
|
||||
|
||||
return {
|
||||
"phone.pos": pos_cal,
|
||||
"phone.rot": rot_cal,
|
||||
"phone.raw_inputs": raw_inputs,
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._teleop = None
|
||||
if self._teleop_thread and self._teleop_thread.is_alive():
|
||||
self._teleop_thread.join(timeout=1.0)
|
||||
self._teleop_thread = None
|
||||
self._latest_pose = None
|
||||
|
||||
|
||||
class Phone(Teleoperator):
|
||||
"""
|
||||
Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API).
|
||||
For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.
|
||||
|
||||
Press and hold **B1** to enable teleoperation. While enabled, the first B1 press
|
||||
captures a reference pose and rotation, when disabled and pressed again the position is reapplied.
|
||||
"""
|
||||
|
||||
config_class = PhoneConfig
|
||||
name = "phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self._phone_impl: Teleoperator
|
||||
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
self._phone_impl = IOSPhone(config)
|
||||
elif self.config.phone_os == PhoneOS.ANDROID:
|
||||
self._phone_impl = AndroidPhone(config)
|
||||
else:
|
||||
raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._phone_impl.is_connected
|
||||
|
||||
def connect(self) -> None:
|
||||
return self._phone_impl.connect()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
return self._phone_impl.calibrate()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._phone_impl.is_calibrated
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.action_features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.feedback_features
|
||||
|
||||
def configure(self) -> None:
|
||||
return self._phone_impl.configure()
|
||||
|
||||
def get_action(self) -> dict:
|
||||
return self._phone_impl.get_action()
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
return self._phone_impl.send_feedback(feedback)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
return self._phone_impl.disconnect()
|
||||
@@ -12,10 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
|
||||
|
||||
class TeleopEvents(Enum):
|
||||
"""Shared constants for teleoperator events across teleoperators."""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
RERECORD_EPISODE = "rerecord_episode"
|
||||
IS_INTERVENTION = "is_intervention"
|
||||
TERMINATE_EPISODE = "terminate_episode"
|
||||
|
||||
|
||||
def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
if config.type == "keyboard":
|
||||
from .keyboard import KeyboardTeleop
|
||||
|
||||
@@ -31,6 +31,7 @@ from termcolor import colored
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import PolicyProcessorPipeline, TransitionKey
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
@@ -101,6 +102,8 @@ def predict_action(
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
use_amp: bool,
|
||||
task: str | None = None,
|
||||
robot_type: str | None = None,
|
||||
@@ -122,10 +125,14 @@ def predict_action(
|
||||
observation["task"] = task if task else ""
|
||||
observation["robot_type"] = robot_type if robot_type else ""
|
||||
|
||||
observation = preprocessor(observation)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
|
||||
|
||||
_torch_available, _torch_version = is_package_available("torch", return_version=True)
|
||||
_transformers_available = is_package_available("transformers")
|
||||
_gym_xarm_available = is_package_available("gym_xarm")
|
||||
_gym_aloha_available = is_package_available("gym_aloha")
|
||||
_gym_pusht_available = is_package_available("gym_pusht")
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Custom rotation utilities to replace scipy.spatial.transform.Rotation."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Rotation:
|
||||
"""
|
||||
Custom rotation class that provides a subset of scipy.spatial.transform.Rotation functionality.
|
||||
|
||||
Supports conversions between rotation vectors, rotation matrices, and quaternions.
|
||||
"""
|
||||
|
||||
def __init__(self, quat: np.ndarray) -> None:
|
||||
"""Initialize rotation from quaternion [x, y, z, w]."""
|
||||
self._quat = np.asarray(quat, dtype=float)
|
||||
# Normalize quaternion
|
||||
norm = np.linalg.norm(self._quat)
|
||||
if norm > 0:
|
||||
self._quat = self._quat / norm
|
||||
|
||||
@classmethod
|
||||
def from_rotvec(cls, rotvec: np.ndarray) -> "Rotation":
|
||||
"""
|
||||
Create rotation from rotation vector using Rodrigues' formula.
|
||||
|
||||
Args:
|
||||
rotvec: Rotation vector [x, y, z] where magnitude is angle in radians
|
||||
|
||||
Returns:
|
||||
Rotation instance
|
||||
"""
|
||||
rotvec = np.asarray(rotvec, dtype=float)
|
||||
angle = np.linalg.norm(rotvec)
|
||||
|
||||
if angle < 1e-8:
|
||||
# For very small angles, use identity quaternion
|
||||
quat = np.array([0.0, 0.0, 0.0, 1.0])
|
||||
else:
|
||||
axis = rotvec / angle
|
||||
half_angle = angle / 2.0
|
||||
sin_half = np.sin(half_angle)
|
||||
cos_half = np.cos(half_angle)
|
||||
|
||||
# Quaternion [x, y, z, w]
|
||||
quat = np.array([axis[0] * sin_half, axis[1] * sin_half, axis[2] * sin_half, cos_half])
|
||||
|
||||
return cls(quat)
|
||||
|
||||
@classmethod
|
||||
def from_matrix(cls, matrix: np.ndarray) -> "Rotation":
|
||||
"""
|
||||
Create rotation from 3x3 rotation matrix.
|
||||
|
||||
Args:
|
||||
matrix: 3x3 rotation matrix
|
||||
|
||||
Returns:
|
||||
Rotation instance
|
||||
"""
|
||||
matrix = np.asarray(matrix, dtype=float)
|
||||
|
||||
# Shepherd's method for converting rotation matrix to quaternion
|
||||
trace = np.trace(matrix)
|
||||
|
||||
if trace > 0:
|
||||
s = np.sqrt(trace + 1.0) * 2 # s = 4 * qw
|
||||
qw = 0.25 * s
|
||||
qx = (matrix[2, 1] - matrix[1, 2]) / s
|
||||
qy = (matrix[0, 2] - matrix[2, 0]) / s
|
||||
qz = (matrix[1, 0] - matrix[0, 1]) / s
|
||||
elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]:
|
||||
s = np.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 # s = 4 * qx
|
||||
qw = (matrix[2, 1] - matrix[1, 2]) / s
|
||||
qx = 0.25 * s
|
||||
qy = (matrix[0, 1] + matrix[1, 0]) / s
|
||||
qz = (matrix[0, 2] + matrix[2, 0]) / s
|
||||
elif matrix[1, 1] > matrix[2, 2]:
|
||||
s = np.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 # s = 4 * qy
|
||||
qw = (matrix[0, 2] - matrix[2, 0]) / s
|
||||
qx = (matrix[0, 1] + matrix[1, 0]) / s
|
||||
qy = 0.25 * s
|
||||
qz = (matrix[1, 2] + matrix[2, 1]) / s
|
||||
else:
|
||||
s = np.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 # s = 4 * qz
|
||||
qw = (matrix[1, 0] - matrix[0, 1]) / s
|
||||
qx = (matrix[0, 2] + matrix[2, 0]) / s
|
||||
qy = (matrix[1, 2] + matrix[2, 1]) / s
|
||||
qz = 0.25 * s
|
||||
|
||||
quat = np.array([qx, qy, qz, qw])
|
||||
return cls(quat)
|
||||
|
||||
@classmethod
|
||||
def from_quat(cls, quat: np.ndarray) -> "Rotation":
|
||||
"""
|
||||
Create rotation from quaternion.
|
||||
|
||||
Args:
|
||||
quat: Quaternion [x, y, z, w] or [w, x, y, z] (specify convention in docstring)
|
||||
This implementation expects [x, y, z, w] format
|
||||
|
||||
Returns:
|
||||
Rotation instance
|
||||
"""
|
||||
return cls(quat)
|
||||
|
||||
def as_matrix(self) -> np.ndarray:
|
||||
"""
|
||||
Convert rotation to 3x3 rotation matrix.
|
||||
|
||||
Returns:
|
||||
3x3 rotation matrix
|
||||
"""
|
||||
qx, qy, qz, qw = self._quat
|
||||
|
||||
# Compute rotation matrix from quaternion
|
||||
return np.array(
|
||||
[
|
||||
[1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)],
|
||||
[2 * (qx * qy + qz * qw), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qx * qw)],
|
||||
[2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx * qx + qy * qy)],
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
|
||||
def as_rotvec(self) -> np.ndarray:
|
||||
"""
|
||||
Convert rotation to rotation vector.
|
||||
|
||||
Returns:
|
||||
Rotation vector [x, y, z] where magnitude is angle in radians
|
||||
"""
|
||||
qx, qy, qz, qw = self._quat
|
||||
|
||||
# Ensure qw is positive for unique representation
|
||||
if qw < 0:
|
||||
qx, qy, qz, qw = -qx, -qy, -qz, -qw
|
||||
|
||||
# Compute angle and axis
|
||||
angle = 2.0 * np.arccos(np.clip(abs(qw), 0.0, 1.0))
|
||||
sin_half_angle = np.sqrt(1.0 - qw * qw)
|
||||
|
||||
if sin_half_angle < 1e-8:
|
||||
# For very small angles, use linearization: rotvec ≈ 2 * [qx, qy, qz]
|
||||
return 2.0 * np.array([qx, qy, qz])
|
||||
|
||||
# Extract axis and scale by angle
|
||||
axis = np.array([qx, qy, qz]) / sin_half_angle
|
||||
return angle * axis
|
||||
|
||||
def as_quat(self) -> np.ndarray:
|
||||
"""
|
||||
Get quaternion representation.
|
||||
|
||||
Returns:
|
||||
Quaternion [x, y, z, w]
|
||||
"""
|
||||
return self._quat.copy()
|
||||
@@ -32,6 +32,7 @@ from lerobot.datasets.utils import load_json, write_json
|
||||
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
||||
|
||||
|
||||
@@ -74,6 +75,8 @@ def save_checkpoint(
|
||||
policy: PreTrainedPolicy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -81,7 +84,9 @@ def save_checkpoint(
|
||||
├── pretrained_model/
|
||||
│ ├── config.json # policy config
|
||||
│ ├── model.safetensors # policy weights
|
||||
│ └── train_config.json # train config
|
||||
│ ├── train_config.json # train config
|
||||
│ ├── processor.json # processor config (if preprocessor provided)
|
||||
│ └── step_*.safetensors # processor state files (if any)
|
||||
└── training_state/
|
||||
├── optimizer_param_groups.json # optimizer param groups
|
||||
├── optimizer_state.safetensors # optimizer state
|
||||
@@ -95,10 +100,15 @@ def save_checkpoint(
|
||||
policy (PreTrainedPolicy): The policy to save.
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
cfg.save_pretrained(pretrained_dir)
|
||||
if preprocessor is not None:
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
|
||||
@@ -12,12 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numbers
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.processor import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
"""Initializes the Rerun SDK for visualizing the control loop."""
|
||||
@@ -28,19 +31,87 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
|
||||
def log_rerun_data(observation: dict[str | Any], action: dict[str | Any]):
|
||||
for obs, val in observation.items():
|
||||
if isinstance(val, float):
|
||||
rr.log(f"observation.{obs}", rr.Scalar(val))
|
||||
elif isinstance(val, np.ndarray):
|
||||
if val.ndim == 1:
|
||||
for i, v in enumerate(val):
|
||||
rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v)))
|
||||
def _is_scalar(x):
|
||||
return (
|
||||
isinstance(x, numbers.Real)
|
||||
or isinstance(x, (np.integer, np.floating))
|
||||
or (isinstance(x, np.ndarray) and x.ndim == 0)
|
||||
)
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
data: list[dict[str | Any] | EnvTransition] | dict[str | Any] | EnvTransition | None = None,
|
||||
*,
|
||||
observation: dict[str, Any] | None = None,
|
||||
action: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
items = data if isinstance(data, list) else ([data] if data is not None else [])
|
||||
|
||||
obs = {} if observation is None else dict(observation)
|
||||
act = {} if action is None else dict(action)
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
if any(isinstance(k, TransitionKey) for k in item.keys()):
|
||||
o = item.get(TransitionKey.OBSERVATION) or {}
|
||||
a = item.get(TransitionKey.ACTION) or {}
|
||||
if isinstance(o, dict):
|
||||
obs.update(o)
|
||||
if isinstance(a, dict):
|
||||
act.update(a)
|
||||
continue
|
||||
|
||||
keys = list(item.keys())
|
||||
has_obs = any(str(k).startswith("observation.") for k in keys)
|
||||
has_act = any(str(k).startswith("action.") for k in keys)
|
||||
|
||||
if has_obs or has_act:
|
||||
if has_obs:
|
||||
obs.update(item)
|
||||
if has_act:
|
||||
act.update(item)
|
||||
else:
|
||||
# No prefixes: assume first is observation, second is action, others are observation
|
||||
if idx == 0:
|
||||
obs.update(item)
|
||||
elif idx == 1:
|
||||
act.update(item)
|
||||
else:
|
||||
rr.log(f"observation.{obs}", rr.Image(val), static=True)
|
||||
for act, val in action.items():
|
||||
if isinstance(val, float):
|
||||
rr.log(f"action.{act}", rr.Scalar(val))
|
||||
elif isinstance(val, np.ndarray):
|
||||
for i, v in enumerate(val):
|
||||
rr.log(f"action.{act}_{i}", rr.Scalar(float(v)))
|
||||
obs.update(item)
|
||||
|
||||
for k, v in obs.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("observation.") else f"observation.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalar(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
if arr.ndim == 1:
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
else:
|
||||
rr.log(key, rr.Image(arr), static=True)
|
||||
|
||||
for k, v in act.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("action.") else f"action.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalar(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
if v.ndim == 1:
|
||||
for i, vi in enumerate(v):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
else:
|
||||
# Fall back to flattening higher-dimensional arrays
|
||||
flat = v.flatten()
|
||||
for i, vi in enumerate(flat):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
|
||||
oid sha256:ee0c29d3782aa1cadcf4dc6ed767d9460ff00fff9fc70b460502340b832eefcc
|
||||
size 5104
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
|
||||
size 33400
|
||||
oid sha256:ea76e6711959fd3f905ec2bdc306f488920f00ec99421e4870d05f6205eb323e
|
||||
size 31672
|
||||
|
||||
+1
-1
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
|
||||
oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c
|
||||
size 515400
|
||||
|
||||
+2
-2
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
|
||||
size 33400
|
||||
oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
|
||||
oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
|
||||
size 992
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
|
||||
oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
|
||||
size 47424
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
|
||||
size 49120
|
||||
oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
|
||||
size 47424
|
||||
|
||||
@@ -23,7 +23,8 @@ from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
|
||||
@@ -37,7 +38,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
train_cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
dataset = make_dataset(train_cfg)
|
||||
dataset_stats = dataset.meta.stats
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(train_cfg.policy, dataset_stats=dataset_stats)
|
||||
policy.train()
|
||||
|
||||
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
|
||||
@@ -49,7 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
batch = preprocessor(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
if output_dict is not None:
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict["loss"] = loss
|
||||
@@ -96,7 +101,12 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
else:
|
||||
actions_queue = train_cfg.policy.n_action_repeats
|
||||
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
actions = {}
|
||||
for i in range(actions_queue):
|
||||
unnormalized_action = policy.select_action(obs).contiguous()
|
||||
action_robot = postprocessor({TransitionKey.ACTION: unnormalized_action}).get(TransitionKey.ACTION)
|
||||
actions[str(i)] = action_robot
|
||||
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
|
||||
oid sha256:d640988f2269cf6aa03c8ee17f9d096edace83d837f90025011fafec5bf53c61
|
||||
size 200
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
||||
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
|
||||
size 16904
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
||||
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
|
||||
size 164
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
||||
size 36312
|
||||
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
|
||||
size 35496
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
|
||||
oid sha256:2212ae7b910d14d723214f5af50985e419f7bd0f4261565ef48b1ef495443d6d
|
||||
size 200
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
|
||||
oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a
|
||||
size 16904
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
|
||||
oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703
|
||||
size 164
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
|
||||
size 36312
|
||||
oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184
|
||||
size 35496
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
assert card.data.tags == ["LeRobot"]
|
||||
assert card.data.task_categories == ["robotics"]
|
||||
assert card.data.configs == [
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_with_tags():
|
||||
tags = ["tag1", "tag2"]
|
||||
card = create_lerobot_dataset_card(tags=tags)
|
||||
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
||||
|
||||
|
||||
def test_calculate_episode_data_index():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
|
||||
|
||||
def test_merge_simple_vectors():
|
||||
g1 = {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": ["ee.x", "ee.y"],
|
||||
}
|
||||
}
|
||||
g2 = {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": ["ee.y", "ee.z"],
|
||||
}
|
||||
}
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
|
||||
assert "action" in out
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
# Names merged with preserved order and de-dupuplication
|
||||
assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"]
|
||||
# Shape correctly recomputed from names length
|
||||
assert out["action"]["shape"] == (3,)
|
||||
|
||||
|
||||
def test_merge_multiple_groups_order_and_dedup():
|
||||
g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
|
||||
g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
|
||||
g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
|
||||
|
||||
out = combine_feature_dicts(g1, g2, g3)
|
||||
|
||||
assert out["action"]["names"] == ["a", "b", "c", "d"]
|
||||
assert out["action"]["shape"] == (4,)
|
||||
|
||||
|
||||
def test_non_vector_last_wins_for_images():
|
||||
# Non-vector (images) with same name should be overwritten by the last image specified
|
||||
g1 = {
|
||||
"observation.images.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 480, 640),
|
||||
"names": ["channels", "height", "width"],
|
||||
}
|
||||
}
|
||||
g2 = {
|
||||
"observation.images.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 720, 1280),
|
||||
"names": ["channels", "height", "width"],
|
||||
}
|
||||
}
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
|
||||
assert out["observation.images.front"]["dtype"] == "image"
|
||||
|
||||
|
||||
def test_dtype_mismatch_raises():
|
||||
g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}}
|
||||
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
|
||||
|
||||
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
|
||||
_ = combine_feature_dicts(g1, g2)
|
||||
|
||||
|
||||
def test_non_dict_passthrough_last_wins():
|
||||
g1 = {"misc": 123}
|
||||
g2 = {"misc": 456}
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
# For non-dict entries the last one wins
|
||||
assert out["misc"] == 456
|
||||
@@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
assert card.data.tags == ["LeRobot"]
|
||||
assert card.data.task_categories == ["robotics"]
|
||||
assert card.data.configs == [
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_with_tags():
|
||||
tags = ["tag1", "tag2"]
|
||||
card = create_lerobot_dataset_card(tags=tags)
|
||||
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
||||
|
||||
|
||||
def test_calculate_episode_data_index():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
@@ -26,7 +26,7 @@ from safetensors.torch import load_file
|
||||
from lerobot import available_policies
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
@@ -39,8 +39,8 @@ from lerobot.policies.factory import (
|
||||
get_policy_class,
|
||||
make_policy,
|
||||
make_policy_config,
|
||||
make_pre_post_processors,
|
||||
)
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
@@ -151,6 +151,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(train_cfg)
|
||||
preprocessor, _ = make_pre_post_processors(train_cfg.policy, None)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
assert isinstance(policy, PreTrainedPolicy)
|
||||
|
||||
@@ -224,6 +225,7 @@ def test_act_backbone_lr():
|
||||
assert cfg.policy.optimizer_lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
preprocessor, _ = make_pre_post_processors(cfg.policy, None)
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
@@ -263,108 +265,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
def test_normalize(insert_temporal_dim):
|
||||
"""
|
||||
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
|
||||
an exception when the forward pass is called without the stats having been provided.
|
||||
|
||||
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
|
||||
expected.
|
||||
"""
|
||||
|
||||
input_features = {
|
||||
"observation.image": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 96, 96),
|
||||
),
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(10,),
|
||||
),
|
||||
}
|
||||
output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
"min": torch.randn(3, 1, 1),
|
||||
"max": torch.randn(3, 1, 1),
|
||||
},
|
||||
"observation.state": {
|
||||
"mean": torch.randn(10),
|
||||
"std": torch.randn(10),
|
||||
"min": torch.randn(10),
|
||||
"max": torch.randn(10),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.randn(5),
|
||||
"std": torch.randn(5),
|
||||
"min": torch.randn(5),
|
||||
"max": torch.randn(5),
|
||||
},
|
||||
}
|
||||
|
||||
bsize = 2
|
||||
input_batch = {
|
||||
"observation.image": torch.randn(bsize, 3, 96, 96),
|
||||
"observation.state": torch.randn(bsize, 10),
|
||||
}
|
||||
output_batch = {
|
||||
"action": torch.randn(bsize, 5),
|
||||
}
|
||||
|
||||
if insert_temporal_dim:
|
||||
tdim = 4
|
||||
|
||||
for key in input_batch:
|
||||
# [2,3,96,96] -> [2,tdim,3,96,96]
|
||||
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
|
||||
|
||||
for key in output_batch:
|
||||
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
||||
|
||||
# test without stats
|
||||
normalize = Normalize(input_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
|
||||
normalize(input_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_normalize = Normalize(input_features, norm_map, stats=None)
|
||||
new_normalize.load_state_dict(normalize.state_dict())
|
||||
new_normalize(input_batch)
|
||||
|
||||
# test without stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
||||
unnormalize(output_batch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multikey", [True, False])
|
||||
def test_multikey_construction(multikey: bool):
|
||||
"""
|
||||
@@ -464,6 +364,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
NOTE: If the test don't pass and you don't change the policy, and note the dependencies version,
|
||||
and you changed your processor, you might have to update the test artifact.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -0,0 +1,355 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for ACT policy processor."""
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.processor_act import make_act_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
RenameProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default ACT configuration for testing."""
|
||||
config = ACTConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
config.device = "cpu"
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
ACTION: {"mean": torch.zeros(4), "std": torch.ones(4)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_act_processor_basic():
|
||||
"""Test basic creation of ACT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
|
||||
|
||||
|
||||
def test_act_processor_normalization():
|
||||
"""Test that ACT processor correctly normalizes and unnormalizes data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is normalized and batched
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized
|
||||
assert postprocessed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_cuda():
|
||||
"""Test ACT processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_accelerate_scenario():
|
||||
"""Test ACT processor in simulated Accelerate scenario (data already on GPU)."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU (not moved unnecessarily)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_act_processor_multi_gpu():
|
||||
"""Test ACT processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
|
||||
|
||||
# Simulate data on different GPU (like in multi-GPU training)
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)}
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1 (not moved to cuda:0)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_act_processor_without_stats():
|
||||
"""Test ACT processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors, but normalization won't have stats
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work (but won't normalize without stats)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_act_processor_save_and_load():
|
||||
"""Test saving and loading ACT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
def test_act_processor_device_placement_preservation():
|
||||
"""Test that ACT processor preserves device placement correctly."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
# Test with CPU config
|
||||
config.device = "cpu"
|
||||
preprocessor, _ = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Process CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_mixed_precision():
|
||||
"""Test ACT processor with mixed precision (float16)."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Modify the device processor to use float16
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Replace DeviceProcessorStep with one that uses float16
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.float16
|
||||
|
||||
|
||||
def test_act_processor_batch_consistency():
|
||||
"""Test that ACT processor handles different batch sizes correctly."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Test single sample (unbatched)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
|
||||
|
||||
# Test already batched data
|
||||
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
|
||||
action_batched = torch.randn(8, 4)
|
||||
transition_batched = create_transition(observation_batched, action_batched)
|
||||
|
||||
processed_batched = preprocessor(transition_batched)
|
||||
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
|
||||
assert processed_batched[TransitionKey.ACTION].shape[0] == 8
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user